From c329f1c5020c3df814be0a1e98cd740c5a4e4621 Mon Sep 17 00:00:00 2001 From: Eugene Kuznetsov Date: Mon, 23 Dec 2019 23:07:25 -0800 Subject: [PATCH 1/4] Support for complex type GEMM and GEMV --- tensorflow/stream_executor/rocm/rocm_blas.cc | 493 ++++++++++++------- tensorflow/stream_executor/rocm/rocm_blas.h | 15 +- 2 files changed, 317 insertions(+), 191 deletions(-) diff --git a/tensorflow/stream_executor/rocm/rocm_blas.cc b/tensorflow/stream_executor/rocm/rocm_blas.cc index a5a588bbbde..f8afc7f1c5f 100644 --- a/tensorflow/stream_executor/rocm/rocm_blas.cc +++ b/tensorflow/stream_executor/rocm/rocm_blas.cc @@ -114,10 +114,10 @@ namespace wrap { __macro(rocblas_zdotc) */ \ __macro(rocblas_sscal) \ __macro(rocblas_dscal) \ - /*__macro(rocblas_cscal) \ + __macro(rocblas_cscal) \ __macro(rocblas_csscal) \ __macro(rocblas_zscal) \ - __macro(rocblas_zdscal) */ \ + __macro(rocblas_zdscal) \ __macro(rocblas_saxpy) \ __macro(rocblas_daxpy) \ /*__macro(rocblas_caxpy) \ @@ -158,9 +158,9 @@ namespace wrap { __macro(rocblas_drotmg) */ \ __macro(rocblas_sgemv) \ __macro(rocblas_dgemv) \ - /*__macro(rocblas_cgemv) \ + __macro(rocblas_cgemv) \ __macro(rocblas_zgemv) \ - __macro(rocblas_sgbmv) \ + /* __macro(rocblas_sgbmv) \ __macro(rocblas_dgbmv) \ __macro(rocblas_cgbmv) \ __macro(rocblas_zgbmv) \ @@ -231,9 +231,9 @@ namespace wrap { __macro(rocblas_sgemm) \ __macro(rocblas_dgemm) \ __macro(rocblas_hgemm) \ - /*__macro(rocblas_cgemm) \ + __macro(rocblas_cgemm) \ __macro(rocblas_zgemm) \ - __macro(rocblas_ssyrk) \ + /* __macro(rocblas_ssyrk) \ __macro(rocblas_dsyrk) \ __macro(rocblas_csyrk) \ __macro(rocblas_zsyrk) \ @@ -285,12 +285,35 @@ STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_hgemm_strided_batched) STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_sgemm_strided_batched) // STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_batched) STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_strided_batched) +STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_cgemm_strided_batched) +STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_zgemm_strided_batched) // STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_cgemm_batched) // STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_zgemm_batched) ROCBLAS_BLAS_ROUTINE_EACH(STREAM_EXECUTOR_ROCBLAS_V2_WRAP) } // namespace wrap + +template +const typename RocBlasTypeConversionHelper::mapped_type* complex_cast(const DeviceMemory &a) +{ + return reinterpret_cast::mapped_type*>(GpuMemory(a)); +} +template +const typename RocBlasTypeConversionHelper::mapped_type* complex_cast(const T &a) +{ + return reinterpret_cast::mapped_type*>(&a); +} +template +typename RocBlasTypeConversionHelper::mapped_type* complex_cast(DeviceMemory *a) +{ + return reinterpret_cast::mapped_type*>(GpuMemoryMutable(a)); +} + +static void blas_log(const char* c) +{ +} + static string ToString(rocblas_status status) { switch (status) { case rocblas_status_success: @@ -436,7 +459,7 @@ bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory *result) { LOG(ERROR) << "rocBLAS does not currently support the ASUM operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -444,13 +467,14 @@ bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory *result) { LOG(ERROR) << "rocBLAS does not currently support the ASUM operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha, const DeviceMemory &x, int incx, DeviceMemory *y, int incy) { + 
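  // The complex_cast helpers defined above are what make the complex BLAS
  // entry points in this patch possible: they reinterpret std::complex<T>
  // host scalars and DeviceMemory<std::complex<T>> buffers as the
  // corresponding rocblas_float_complex / rocblas_double_complex pointers.
  // This is a layout assumption (two packed T values, no conversion), not a
  // copy; a hypothetical use:
  //   DeviceMemory<std::complex<float>> x = ...;
  //   const rocblas_float_complex *rb_x = complex_cast(x);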
blas_log("DoBlasAxpy"); return DoBlasInternal(wrap::rocblas_saxpy, stream, true /* = pointer_mode_host */, elem_count, &alpha, GpuMemory(x), incx, GpuMemoryMutable(y), incy); @@ -459,6 +483,7 @@ bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha, bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha, const DeviceMemory &x, int incx, DeviceMemory *y, int incy) { + blas_log("DoBlasAxpy"); return DoBlasInternal(wrap::rocblas_daxpy, stream, true /* = pointer_mode_host */, elem_count, &alpha, GpuMemory(x), incx, GpuMemoryMutable(y), incy); @@ -469,7 +494,7 @@ bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory> *y, int incy) { LOG(ERROR) << "rocBLAS does not currently support the AXPY operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -478,7 +503,7 @@ bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory> *y, int incy) { LOG(ERROR) << "rocBLAS does not currently support the AXPY operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -502,7 +527,7 @@ bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory> *y, int incy) { LOG(ERROR) << "rocBLAS does not currently support the COPY operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -510,7 +535,7 @@ bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory> *y, int incy) { LOG(ERROR) << "rocBLAS does not currently support the COPY operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -518,6 +543,7 @@ bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, const DeviceMemory &y, int incy, DeviceMemory *result) { + blas_log("DoBlasDot"); return DoBlasInternal( wrap::rocblas_sdot, stream, false /* = pointer_mode_host */, elem_count, GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result)); @@ -527,6 +553,7 @@ bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, const DeviceMemory &y, int incy, DeviceMemory *result) { + blas_log("DoBlasDot"); return DoBlasInternal( wrap::rocblas_ddot, stream, false /* = pointer_mode_host */, elem_count, GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result)); @@ -537,7 +564,7 @@ bool ROCMBlas::DoBlasDotc(Stream *stream, uint64 elem_count, const DeviceMemory> &y, int incy, DeviceMemory> *result) { LOG(ERROR) << "rocBLAS does not currently support the DOT operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -546,7 +573,7 @@ bool ROCMBlas::DoBlasDotc(Stream *stream, uint64 elem_count, const DeviceMemory> &y, int incy, DeviceMemory> *result) { LOG(ERROR) << "rocBLAS does not currently support the DOT operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -555,7 +582,7 @@ bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count, const DeviceMemory> &y, int incy, DeviceMemory> *result) { LOG(ERROR) << "rocBLAS does not currently support the DOT operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -564,7 +591,7 @@ bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count, const DeviceMemory> &y, int incy, DeviceMemory> *result) 
{ LOG(ERROR) << "rocBLAS does not currently support the DOT operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -588,7 +615,7 @@ bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory *result) { LOG(ERROR) << "rocBLAS does not currently support the NRM2 operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -596,7 +623,7 @@ bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory *result) { LOG(ERROR) << "rocBLAS does not currently support the NRM2 operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -604,7 +631,7 @@ bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory *x, int incx, DeviceMemory *y, int incy, float c, float s) { LOG(ERROR) << "rocBLAS does not currently support the ROT operation " - << "for the \"float\" dataype"; + << "for the \"float\" datatype"; return false; } @@ -613,7 +640,7 @@ bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory *y, int incy, double c, double s) { LOG(ERROR) << "rocBLAS does not currently support the ROT operation " - << "for the \"double\" dataype"; + << "for the \"double\" datatype"; return false; } @@ -622,7 +649,7 @@ bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory> *y, int incy, float c, float s) { LOG(ERROR) << "rocBLAS does not currently support the ROT operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -631,7 +658,7 @@ bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory> *y, int incy, double c, double s) { LOG(ERROR) << "rocBLAS does not currently support the ROT operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -639,7 +666,7 @@ bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory *a, DeviceMemory *b, DeviceMemory *c, DeviceMemory *s) { LOG(ERROR) << "rocBLAS does not currently support the ROTG operation " - << "for the \"float\" dataype"; + << "for the \"float\" datatype"; return false; } @@ -647,7 +674,7 @@ bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory *a, DeviceMemory *b, DeviceMemory *c, DeviceMemory *s) { LOG(ERROR) << "rocBLAS does not currently support the ROTG operation " - << "for the \"double\" dataype"; + << "for the \"double\" datatype"; return false; } @@ -656,7 +683,7 @@ bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory> *a, DeviceMemory *c, DeviceMemory> *s) { LOG(ERROR) << "rocBLAS does not currently support the ROTG operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -665,7 +692,7 @@ bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory> *a, DeviceMemory *c, DeviceMemory> *s) { LOG(ERROR) << "rocBLAS does not currently support the ROTG operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -674,7 +701,7 @@ bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory *y, int incy, const DeviceMemory ¶m) { LOG(ERROR) << "rocBLAS does not currently support the ROTM operation " - << "for the \"float\" dataype"; + << "for the \"float\" datatype"; return false; } @@ -683,7 +710,7 @@ bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory *y, int incy, const DeviceMemory ¶m) { LOG(ERROR) << "rocBLAS does not currently support the ROTM operation " 
- << "for the \"double\" dataype"; + << "for the \"double\" datatype"; return false; } @@ -692,7 +719,7 @@ bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory *d1, const DeviceMemory &y1, DeviceMemory *param) { LOG(ERROR) << "rocBLAS does not currently support the ROTMG operation " - << "for the \"float\" dataype"; + << "for the \"float\" datatype"; return false; } @@ -701,12 +728,13 @@ bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory *d1, const DeviceMemory &y1, DeviceMemory *param) { LOG(ERROR) << "rocBLAS does not currently support the ROTMG operation " - << "for the \"double\" dataype"; + << "for the \"double\" datatype"; return false; } bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha, DeviceMemory *x, int incx) { + blas_log("DoBlasScal"); return DoBlasInternal(wrap::rocblas_sscal, stream, true /* = pointer_mode_host */, elem_count, &alpha, GpuMemoryMutable(x), incx); @@ -721,32 +749,32 @@ bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha, bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha, DeviceMemory> *x, int incx) { - LOG(ERROR) << "rocBLAS does not currently support the SCAL operation " - << "for the \"complex\" dataype"; - return false; + return DoBlasInternal(wrap::rocblas_csscal, stream, + true /* = pointer_mode_host */, elem_count, &alpha, + complex_cast(x), incx); } bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha, DeviceMemory> *x, int incx) { - LOG(ERROR) << "rocBLAS does not currently support the SCAL operation " - << "for the \"complex\" dataype"; - return false; + return DoBlasInternal(wrap::rocblas_zdscal, stream, + true /* = pointer_mode_host */, elem_count, &alpha, + complex_cast(x), incx); } bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, std::complex alpha, DeviceMemory> *x, int incx) { - LOG(ERROR) << "rocBLAS does not currently support the SCAL operation " - << "for the \"complex\" dataype"; - return false; + return DoBlasInternal(wrap::rocblas_cscal, stream, + true /* = pointer_mode_host */, elem_count, complex_cast(alpha), + complex_cast(x), incx); } bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, std::complex alpha, DeviceMemory> *x, int incx) { - LOG(ERROR) << "rocBLAS does not currently support the SCAL operation " - << "for the \"complex\" dataype"; - return false; + return DoBlasInternal(wrap::rocblas_zscal, stream, + true /* = pointer_mode_host */, elem_count, complex_cast(alpha), + complex_cast(x), incx); } bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count, @@ -769,7 +797,7 @@ bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory> *x, int incx, DeviceMemory> *y, int incy) { LOG(ERROR) << "rocBLAS does not currently support the SWAP operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -777,7 +805,7 @@ bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory> *x, int incx, DeviceMemory> *y, int incy) { LOG(ERROR) << "rocBLAS does not currently support the SWAP operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -801,7 +829,7 @@ bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory *result) { LOG(ERROR) << "rocBLAS does not currently support the AMAX operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -809,7 +837,7 @@ bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 
elem_count, const DeviceMemory> &x, int incx, DeviceMemory *result) { LOG(ERROR) << "rocBLAS does not currently support the AMAX operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -833,7 +861,7 @@ bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory *result) { LOG(ERROR) << "rocBLAS does not currently support the AMIN operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -841,7 +869,7 @@ bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory *result) { LOG(ERROR) << "rocBLAS does not currently support the AMIN operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -851,7 +879,7 @@ bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, const DeviceMemory &x, int incx, float beta, DeviceMemory *y, int incy) { LOG(ERROR) << "rocBLAS does not currently support the GBMV operation " - << "for the \"float\" dataype"; + << "for the \"float\" datatype"; return false; } @@ -861,7 +889,7 @@ bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, const DeviceMemory &x, int incx, double beta, DeviceMemory *y, int incy) { LOG(ERROR) << "rocBLAS does not currently support the GBMV operation " - << "for the \"double\" dataype"; + << "for the \"double\" datatype"; return false; } @@ -873,7 +901,7 @@ bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, std::complex beta, DeviceMemory> *y, int incy) { LOG(ERROR) << "rocBLAS does not currently support the GBMV operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -885,7 +913,7 @@ bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, std::complex beta, DeviceMemory> *y, int incy) { LOG(ERROR) << "rocBLAS does not currently support the GBMV operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -893,6 +921,7 @@ bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha, const DeviceMemory &a, int lda, const DeviceMemory &x, int incx, float beta, DeviceMemory *y, int incy) { + blas_log("DoBlasGemv"); return DoBlasInternal( wrap::rocblas_sgemv, stream, true /* = pointer_mode_host */, ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x), @@ -903,6 +932,7 @@ bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, const DeviceMemory &a, int lda, const DeviceMemory &x, int incx, double beta, DeviceMemory *y, int incy) { + blas_log("DoBlasGemv"); return DoBlasInternal( wrap::rocblas_dgemv, stream, true /* = pointer_mode_host */, ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x), @@ -915,9 +945,13 @@ bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, const DeviceMemory> &x, int incx, std::complex beta, DeviceMemory> *y, int incy) { - LOG(ERROR) << "rocBLAS does not currently support the GEMV operation " - << "for the \"complex\" dataype"; - return false; + blas_log("DoBlasGemv"); + return DoBlasInternal( + wrap::rocblas_cgemv, stream, true /* = pointer_mode_host */, + ROCMBlasTranspose(trans), m, n, complex_cast(alpha), + complex_cast(a), lda, + complex_cast(x), incx, + complex_cast(beta), complex_cast(y), incy); } bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, @@ 
-926,9 +960,13 @@ bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the GEMV operation "
-             << "for the \"complex\" dataype";
-  return false;
+  blas_log("DoBlasGemv");
+  return DoBlasInternal(
+      wrap::rocblas_zgemv, stream, true /* = pointer_mode_host */,
+      ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a),
+      lda, complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
 }
 
 bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
@@ -955,7 +993,7 @@ bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *a, int lda) {
   LOG(ERROR) << "rocBLAS does not currently support the GER operation "
-             << "for the \"complex\" dataype";
+             << "for the \"complex\" datatype";
   return false;
 }
 
@@ -965,7 +1003,7 @@ bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *a, int lda) {
   LOG(ERROR) << "rocBLAS does not currently support the GER operation "
-             << "for the \"complex\" dataype";
+             << "for the \"complex\" datatype";
   return false;
 }
 
@@ -975,7 +1013,7 @@ bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *a, int lda) {
   LOG(ERROR) << "rocBLAS does not currently support the GERU operation "
-             << "for the \"complex\" dataype";
+             << "for the \"complex\" datatype";
   return false;
 }
 
@@ -985,7 +1023,7 @@ bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *a, int lda) {
   LOG(ERROR) << "rocBLAS does not currently support the GERU operation "
-             << "for the \"complex\" dataype";
+             << "for the \"complex\" datatype";
   return false;
 }
 
@@ -996,7 +1034,7 @@ bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) {
   LOG(ERROR) << "rocBLAS does not currently support the HBMV operation "
-             << "for the \"complex\" dataype";
+             << "for the \"complex\" datatype";
   return false;
 }
 
@@ -1007,7 +1045,7 @@ bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) {
   LOG(ERROR) << "rocBLAS does not currently support the HBMV operation "
-             << "for the \"complex\" dataype";
+             << "for the \"complex\" datatype";
   return false;
 }
 
@@ -1018,7 +1056,7 @@ bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) {
   LOG(ERROR) << "rocBLAS does not currently support the HEMV operation "
-             << "for the \"complex\" dataype";
+             << "for the \"complex\" datatype";
   return false;
 }
 
@@ -1029,7 +1067,7 @@ bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) {
   LOG(ERROR) << "rocBLAS does not currently support the HEMV operation "
-             << "for the \"complex\" dataype";
+             << "for the \"complex\" datatype";
   return false;
 }
 
@@ -1038,7 +1076,7 @@ bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<std::complex<float>> *a, int lda) {
   LOG(ERROR) << "rocBLAS does not currently support the HER operation "
-             << "for the \"complex\" dataype";
+             << "for the \"complex\" datatype";
   return false;
 }
 
@@ -1047,7 +1085,7 @@ bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<std::complex<double>> *a,
int lda) { LOG(ERROR) << "rocBLAS does not currently support the HER operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1057,7 +1095,7 @@ bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory> &y, int incy, DeviceMemory> *a, int lda) { LOG(ERROR) << "rocBLAS does not currently support the HER2 operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1067,7 +1105,7 @@ bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory> &y, int incy, DeviceMemory> *a, int lda) { LOG(ERROR) << "rocBLAS does not currently support the HER2 operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1078,7 +1116,7 @@ bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, std::complex beta, DeviceMemory> *y, int incy) { LOG(ERROR) << "rocBLAS does not currently support the HPMV operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1089,7 +1127,7 @@ bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, std::complex beta, DeviceMemory> *y, int incy) { LOG(ERROR) << "rocBLAS does not currently support the HPMV operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1098,7 +1136,7 @@ bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory> &x, int incx, DeviceMemory> *ap) { LOG(ERROR) << "rocBLAS does not currently support the HPR operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1107,7 +1145,7 @@ bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory> &x, int incx, DeviceMemory> *ap) { LOG(ERROR) << "rocBLAS does not currently support the HPR operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1117,7 +1155,7 @@ bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory> &y, int incy, DeviceMemory> *ap) { LOG(ERROR) << "rocBLAS does not currently support the HPR2 operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1127,7 +1165,7 @@ bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory> &y, int incy, DeviceMemory> *ap) { LOG(ERROR) << "rocBLAS does not currently support the HPR2 operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1136,7 +1174,7 @@ bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, int lda, const DeviceMemory &x, int incx, float beta, DeviceMemory *y, int incy) { LOG(ERROR) << "rocBLAS does not currently support the SBMV operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1146,7 +1184,7 @@ bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, int lda, const DeviceMemory &x, int incx, double beta, DeviceMemory *y, int incy) { LOG(ERROR) << "rocBLAS does not currently support the SBMV operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1155,7 +1193,7 @@ bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory &x, int incx, float beta, DeviceMemory *y, int incy) { LOG(ERROR) << "rocBLAS does not 
currently support the SPMV operation " - << "for the \"float\" dataype"; + << "for the \"float\" datatype"; return false; } @@ -1164,7 +1202,7 @@ bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory &x, int incx, double beta, DeviceMemory *y, int incy) { LOG(ERROR) << "rocBLAS does not currently support the SPMV operation " - << "for the \"double\" dataype"; + << "for the \"double\" datatype"; return false; } @@ -1172,7 +1210,7 @@ bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, const DeviceMemory &x, int incx, DeviceMemory *ap) { LOG(ERROR) << "rocBLAS does not currently support the SPR operation " - << "for the \"float\" dataype"; + << "for the \"float\" datatype"; return false; } @@ -1180,7 +1218,7 @@ bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, double alpha, const DeviceMemory &x, int incx, DeviceMemory *ap) { LOG(ERROR) << "rocBLAS does not currently support the SPR operation " - << "for the \"double\" dataype"; + << "for the \"double\" datatype"; return false; } @@ -1189,7 +1227,7 @@ bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory &y, int incy, DeviceMemory *ap) { LOG(ERROR) << "rocBLAS does not currently support the SPR2 operation " - << "for the \"float\" dataype"; + << "for the \"float\" datatype"; return false; } @@ -1198,7 +1236,7 @@ bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory &y, int incy, DeviceMemory *ap) { LOG(ERROR) << "rocBLAS does not currently support the SPR2 operation " - << "for the \"double\" dataype"; + << "for the \"double\" datatype"; return false; } @@ -1207,7 +1245,7 @@ bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory &x, int incx, float beta, DeviceMemory *y, int incy) { LOG(ERROR) << "rocBLAS does not currently support the SYMV operation " - << "for the \"float\" dataype"; + << "for the \"float\" datatype"; return false; } @@ -1216,7 +1254,7 @@ bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory &x, int incx, double beta, DeviceMemory *y, int incy) { LOG(ERROR) << "rocBLAS does not currently support the SYMV operation " - << "for the \"double\" dataype"; + << "for the \"double\" datatype"; return false; } @@ -1243,7 +1281,7 @@ bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory &y, int incy, DeviceMemory *a, int lda) { LOG(ERROR) << "rocBLAS does not currently support the SYR2 operation " - << "for the \"float\" dataype"; + << "for the \"float\" datatype"; return false; } @@ -1252,7 +1290,7 @@ bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory &y, int incy, DeviceMemory *a, int lda) { LOG(ERROR) << "rocBLAS does not currently support the SYR2 operation " - << "for the \"double\" dataype"; + << "for the \"double\" datatype"; return false; } @@ -1261,7 +1299,7 @@ bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo, uint64 k, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TBMV operation " - << "for the \"float\" dataype"; + << "for the \"float\" datatype"; return false; } @@ -1270,7 +1308,7 @@ bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo, uint64 k, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TBMV operation " - << "for the 
\"double\" dataype"; + << "for the \"double\" datatype"; return false; } @@ -1280,7 +1318,7 @@ bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo, int lda, DeviceMemory> *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TBMV operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1290,7 +1328,7 @@ bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo, int lda, DeviceMemory> *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TBMV operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1299,7 +1337,7 @@ bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo, uint64 k, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TBSV operation " - << "for the \"float\" dataype"; + << "for the \"float\" datatype"; return false; } @@ -1308,7 +1346,7 @@ bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo, uint64 k, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TBSV operation " - << "for the \"double\" dataype"; + << "for the \"double\" datatype"; return false; } @@ -1318,7 +1356,7 @@ bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo, int lda, DeviceMemory> *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TBSV operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1328,7 +1366,7 @@ bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo, int lda, DeviceMemory> *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TBSV operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1337,7 +1375,7 @@ bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo, const DeviceMemory &ap, DeviceMemory *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TPMV operation " - << "for the \"float\" dataype"; + << "for the \"float\" datatype"; return false; } @@ -1346,7 +1384,7 @@ bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo, const DeviceMemory &ap, DeviceMemory *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TPMV operation " - << "for the \"double\" dataype"; + << "for the \"double\" datatype"; return false; } @@ -1355,7 +1393,7 @@ bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo, const DeviceMemory> &ap, DeviceMemory> *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TPMV operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1364,7 +1402,7 @@ bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo, const DeviceMemory> &ap, DeviceMemory> *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TPMV operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1373,7 +1411,7 @@ bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo, const DeviceMemory &ap, DeviceMemory *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TPSV operation " - << "for the \"float\" dataype"; + << "for the \"float\" datatype"; return false; } @@ -1382,7 +1420,7 @@ bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo, const DeviceMemory &ap, DeviceMemory *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TPSV operation " - 
<< "for the \"double\" dataype"; + << "for the \"double\" datatype"; return false; } @@ -1391,7 +1429,7 @@ bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo, const DeviceMemory> &ap, DeviceMemory> *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TPSV operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1400,7 +1438,7 @@ bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo, const DeviceMemory> &ap, DeviceMemory> *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TPSV operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1409,7 +1447,7 @@ bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TRMV operation " - << "for the \"float\" dataype"; + << "for the \"float\" datatype"; return false; } @@ -1418,7 +1456,7 @@ bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TRMV operation " - << "for the \"double\" dataype"; + << "for the \"double\" datatype"; return false; } @@ -1427,7 +1465,7 @@ bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo, const DeviceMemory> &a, int lda, DeviceMemory> *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TRMV operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1436,7 +1474,7 @@ bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo, const DeviceMemory> &a, int lda, DeviceMemory> *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TRMV operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1445,7 +1483,7 @@ bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TRSV operation " - << "for the \"float\" dataype"; + << "for the \"float\" datatype"; return false; } @@ -1454,7 +1492,7 @@ bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TRSV operation " - << "for the \"double\" dataype"; + << "for the \"double\" datatype"; return false; } @@ -1463,7 +1501,7 @@ bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo, const DeviceMemory> &a, int lda, DeviceMemory> *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TRSV operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1472,7 +1510,7 @@ bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo, const DeviceMemory> &a, int lda, DeviceMemory> *x, int incx) { LOG(ERROR) << "rocBLAS does not currently support the TRSV operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -1481,6 +1519,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, float alpha, const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, float beta, DeviceMemory *c, int ldc) { + blas_log("DoBlasGemm"); VLOG(1) << absl::StreamFormat( "doing rocBLAS SGEMM: at=%d bt=%d m=%u n=%u " "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f " @@ -1526,6 +1565,7 @@ bool 
ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                          float alpha, const DeviceMemory<float> &a, int lda,
                          const DeviceMemory<float> &b, int ldb, float beta,
                          DeviceMemory<float> *c, int ldc) {
+  blas_log("DoBlasGemm");
   VLOG(1) << absl::StreamFormat(
       "doing rocBLAS SGEMM: at=%d bt=%d m=%u n=%u "
       "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
@@ -1565,6 +1605,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                          double alpha, const DeviceMemory<double> &a, int lda,
                          const DeviceMemory<double> &b, int ldb, double beta,
                          DeviceMemory<double> *c, int ldc) {
+  blas_log("DoBlasGemm");
   return DoBlasInternal(
       wrap::rocblas_dgemm, stream, true /* = pointer_mode_host */,
       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha,
@@ -1578,9 +1619,11 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                          const DeviceMemory<std::complex<float>> &b, int ldb,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *c, int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the GEMM operation "
-             << "for the \"complex\" dataype";
-  return false;
+  blas_log("DoBlasGemm");
+  return DoBlasInternal(
+      wrap::rocblas_cgemm, stream, true /* = pointer_mode_host */,
+      ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
+      complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
+      complex_cast(beta), complex_cast(c), ldc);
 }
 
 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
@@ -1590,9 +1633,11 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                          const DeviceMemory<std::complex<double>> &b, int ldb,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *c, int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the GEMM operation "
-             << "for the \"complex\" dataype";
-  return false;
+  blas_log("DoBlasGemm");
+  return DoBlasInternal(
+      wrap::rocblas_zgemm, stream, true /* = pointer_mode_host */,
+      ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
+      complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
+      complex_cast(beta), complex_cast(c), ldc);
 }
 
 bool ROCMBlas::DoBlasGemvWithProfiling(
@@ -1739,7 +1784,7 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
     blas::ProfileResult *output_profile_result) {
   LOG(ERROR) << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
-             << "for the \"int8\" dataype";
+             << "for the \"int8\" datatype";
   return false;
 }
 
@@ -1753,7 +1798,7 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
     blas::AlgorithmType algorithm,
     blas::ProfileResult *output_profile_result) {
   LOG(ERROR) << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
-             << "for the \"half\" dataype";
+             << "for the \"half\" datatype";
   return false;
 }
 
@@ -1766,7 +1811,7 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
     blas::AlgorithmType algorithm,
     blas::ProfileResult *output_profile_result) {
   LOG(ERROR) << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
-             << "for the \"float\" dataype";
+             << "for the \"float\" datatype";
   return false;
 }
 
@@ -1779,7 +1824,7 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
     blas::AlgorithmType algorithm,
     blas::ProfileResult *output_profile_result) {
   LOG(ERROR) << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
-             << "for the \"double\" dataype";
+             << "for the \"double\" datatype";
   return false;
 }
 
@@ -1794,7 +1839,7 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
     blas::ProfileResult *output_profile_result) {
   LOG(ERROR) << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
-             << "for the \"complex\" dataype";
+             << "for the \"complex\" datatype";
   return false;
 }
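For reference, the complex GEMM paths added above collapse to a single rocBLAS
call once complex_cast has reinterpreted the operands. A minimal sketch of the
equivalent direct call, assuming a valid rocblas_handle h, device-resident
rocblas_float_complex buffers A, B, C (all names hypothetical), and that
rocblas_float_complex is the plain two-float struct used by this rocBLAS
version:

    rocblas_float_complex alpha{1.0f, 0.0f};  // hypothetical scalars
    rocblas_float_complex beta{0.0f, 0.0f};
    rocblas_cgemm(h, rocblas_operation_none, rocblas_operation_none, m, n, k,
                  &alpha, A, lda, B, ldb, &beta, C, ldc);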
@@ -1809,10 +1854,63 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
     blas::ProfileResult *output_profile_result) {
   LOG(ERROR) << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
-             << "for the \"complex\" dataype";
+             << "for the \"complex\" datatype";
   return false;
 }
 
+// Copies between the consolidated strided buffer (device_memory) and the
+// individual batch matrices (raw_ptrs[i]), which sit matrix_byte_size apart
+// in the strided buffer: gather == true copies raw_ptrs[i] into
+// device_memory, gather == false copies back out. The loop below minimizes
+// the number of memcpy requests by coalescing runs of physically contiguous
+// source matrices into a single copy.
+template <typename MAPPED_T>
+port::Status ReorganizeMemory(Stream *stream,
+                              DeviceMemory<MAPPED_T> *device_memory,
+                              const std::vector<MAPPED_T *> &raw_ptrs,
+                              int batch_count, uint64_t batch_stride,
+                              bool gather) {
+  assert(batch_count > 0);
+  char *device_memory_ptr = static_cast<char *>(device_memory->opaque());
+  char *src_ptr = reinterpret_cast<char *>(raw_ptrs[0]);
+  char *dst_ptr = device_memory_ptr;
+  size_t matrix_byte_size = batch_stride * sizeof(MAPPED_T);
+  uint64_t cur_stride_size = matrix_byte_size;
+
+  for (int i = 1; i < batch_count; ++i) {
+    if (reinterpret_cast<char *>(raw_ptrs[i]) == src_ptr + cur_stride_size) {
+      cur_stride_size += matrix_byte_size;
+    } else {
+      DeviceMemoryBase src_mem = DeviceMemoryBase(src_ptr, cur_stride_size);
+      DeviceMemoryBase target_mem = DeviceMemoryBase(dst_ptr, cur_stride_size);
+      bool a_status =
+          gather
+              ? stream->ThenMemcpy(&target_mem, src_mem, cur_stride_size).ok()
+              : stream->ThenMemcpy(&src_mem, target_mem, cur_stride_size).ok();
+      if (!a_status) {
+        return port::Status(
+            port::error::INTERNAL,
+            "failed to copy device memory in ROCMBlas::DoBlasGemmBatched");
+      }
+      src_ptr = reinterpret_cast<char *>(raw_ptrs[i]);
+      dst_ptr = device_memory_ptr + i * matrix_byte_size;
+      cur_stride_size = matrix_byte_size;
+    }
+  }
+
+  DeviceMemoryBase src_mem = DeviceMemoryBase(src_ptr, cur_stride_size);
+  DeviceMemoryBase target_mem = DeviceMemoryBase(dst_ptr, cur_stride_size);
+  bool a_status =
+      gather ? stream->ThenMemcpy(&target_mem, src_mem, cur_stride_size).ok()
+             : stream->ThenMemcpy(&src_mem, target_mem, cur_stride_size).ok();
+  if (!a_status) {
+    return port::Status(
+        port::error::INTERNAL,
+        "failed to copy device memory in ROCMBlas::DoBlasGemmBatched");
+  }
+  return port::Status::OK();
+}
+
 template <typename T>
 port::Status ROCMBlas::AllocateStridedBuffer(
     const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
@@ -1822,7 +1920,9 @@ port::Status ROCMBlas::AllocateStridedBuffer(
     std::unique_ptr<TemporaryDeviceMemory<
         typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
     DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
-        *device_memory) {
+        *device_memory,
+    bool copy_data, bool &reallocated) {
   assert(device_memory != nullptr);
   using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type;
 
@@ -1843,6 +1943,7 @@ port::Status ROCMBlas::AllocateStridedBuffer(
   if (!needs_allocate_strided) {
     *device_memory = DeviceMemory<MAPPED_T>(
         DeviceMemoryBase(raw_ptrs[0], matrix_batch_byte_size));
+    reallocated = false;
     return port::Status::OK();
   }
 
@@ -1859,22 +1960,14 @@ port::Status ROCMBlas::AllocateStridedBuffer(
         DeviceMemory<MAPPED_T>(*(*temp_memory)->mutable_device_memory());
   }
 
-  for (int i = 0; i < batch_count; ++i) {
-    char *device_memory_ptr = static_cast<char *>(device_memory->opaque());
-    DeviceMemoryBase src_mem = DeviceMemoryBase(raw_ptrs[i], matrix_byte_size);
-    DeviceMemoryBase target_mem = DeviceMemoryBase(
-        device_memory_ptr + i * matrix_byte_size, matrix_byte_size);
-    bool a_status =
-        stream->ThenMemcpy(&target_mem, src_mem, matrix_byte_size).ok();
-    if (!a_status) {
-      return port::Status(
-          port::error::INTERNAL,
-          "failed to copy device memory in ROCMBlas::DoBlasGemmBatched");
-    }
-  }
-  return port::Status::OK();
+  reallocated = true;
+  if (copy_data) {
+    return ReorganizeMemory(stream, device_memory, raw_ptrs, batch_count,
+                            batch_stride, true);
+  }
+  return port::Status::OK();
 }
 
 template <typename FuncT, typename T>
 port::Status ROCMBlas::DoBlasGemmBatchedInternal(
     FuncT rocblas_func, Stream *stream, blas::Transpose transa,
@@ -1925,9 +2017,10 @@ port::Status ROCMBlas::DoBlasGemmBatchedInternal(
   DeviceMemory<MAPPED_T> a;
   // Make sure the temporary memory are in-scope before the function returns
   std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> a_temp;
+  bool reallocated_a, reallocated_b, reallocated_c;
   port::Status a_allocation_status =
       AllocateStridedBuffer<T>(a_raw_ptrs, batch_count, batch_stride_a,
-                               scratch_allocator, stream, &a_temp, &a);
+                               scratch_allocator, stream, &a_temp, &a, true,
+                               reallocated_a);
   if (a_allocation_status != port::Status::OK()) {
     return a_allocation_status;
   }
 
@@ -1936,7 +2029,7 @@ port::Status ROCMBlas::DoBlasGemmBatchedInternal(
   std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> b_temp;
   port::Status b_allocation_status =
       AllocateStridedBuffer<T>(b_raw_ptrs, batch_count, batch_stride_b,
-                               scratch_allocator, stream, &b_temp, &b);
+                               scratch_allocator, stream, &b_temp, &b, true,
+                               reallocated_b);
   if (b_allocation_status != port::Status::OK()) {
     return b_allocation_status;
   }
 
@@ -1945,7 +2038,7 @@ port::Status ROCMBlas::DoBlasGemmBatchedInternal(
   std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> c_temp;
   port::Status c_allocation_status =
       AllocateStridedBuffer<T>(c_raw_ptrs, batch_count, batch_stride_c,
-                               scratch_allocator, stream, &c_temp, &c);
+                               scratch_allocator, stream, &c_temp, &c, true,
+                               reallocated_c);  // can disable copy if beta=0
   if (c_allocation_status != port::Status::OK()) {
     return c_allocation_status;
   }
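The coalescing in ReorganizeMemory above can be traced on a small hypothetical
batch; the addresses and sizes below are illustrative only:

    // raw_ptrs = {0x1000, 0x1400, 0x1800, 0x9000}, matrix_byte_size = 0x400
    // i = 1: 0x1400 == 0x1000 + 0x400  -> run grows to 0x800
    // i = 2: 0x1800 == 0x1000 + 0x800  -> run grows to 0xC00
    // i = 3: 0x9000 != 0x1000 + 0xC00  -> flush one 0xC00-byte memcpy
    // tail:  one 0x400-byte memcpy for the last matrix
    // Result: 2 memcpy requests instead of 4 for this batch.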
@@ -1953,19 +2046,19 @@ port::Status ROCMBlas::DoBlasGemmBatchedInternal(
   MAPPED_T *alpha_ptr = reinterpret_cast<MAPPED_T *>(&alpha);
   MAPPED_T *beta_ptr = reinterpret_cast<MAPPED_T *>(&beta);
 
   bool ok = DoBlasInternal(rocblas_func, stream, true /* = pointer_mode_host */,
                            ROCMBlasTranspose(transa), ROCMBlasTranspose(transb),
                            m, n, k, GpuComplex(alpha_ptr), GpuMemory(a), lda,
                            batch_stride_a, GpuMemory(b), ldb, batch_stride_b,
                            GpuComplex(beta_ptr), GpuMemoryMutable(&c), ldc,
                            batch_stride_c, batch_count);
-  if (ok) {
-    return port::Status::OK();
-  } else {
-    return port::Status(port::error::INTERNAL,
-                        "failed BLAS call, see log for details");
-  }
+  if (!ok) {
+    return port::Status(port::error::INTERNAL,
+                        "failed BLAS call, see log for details");
+  }
+  if (reallocated_c) {
+    return ReorganizeMemory(stream, &c, c_raw_ptrs, batch_count,
+                            batch_stride_c, false);
+  }
+  return port::Status::OK();
 }
 
 bool ROCMBlas::DoBlasGemmBatched(
@@ -1974,7 +2067,8 @@ bool ROCMBlas::DoBlasGemmBatched(
     const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
     const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
     float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+  blas_log("DoBlasGemmBatched");
   const Eigen::half alpha_half(alpha);
   const Eigen::half beta_half(beta);
 
@@ -1996,6 +2090,7 @@ bool ROCMBlas::DoBlasGemmBatched(
     const port::ArraySlice<DeviceMemory<float> *> &b_array, int ldb,
     float beta, const port::ArraySlice<DeviceMemory<float> *> &c_array,
     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+  blas_log("DoBlasGemmBatched");
   port::Status status = DoBlasGemmBatchedInternal(
       wrap::rocblas_sgemm_strided_batched, stream, transa, transb, m, n, k,
       alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
@@ -2013,6 +2108,7 @@ bool ROCMBlas::DoBlasGemmBatched(
     const port::ArraySlice<DeviceMemory<double> *> &b_array, int ldb,
     double beta, const port::ArraySlice<DeviceMemory<double> *> &c_array,
     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+  blas_log("DoBlasGemmBatched");
   port::Status status = DoBlasGemmBatchedInternal(
       wrap::rocblas_dgemm_strided_batched, stream, transa, transb, m, n, k,
       alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
@@ -2032,11 +2128,18 @@ bool ROCMBlas::DoBlasGemmBatched(
     int ldb, std::complex<float> beta,
     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c_array,
     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
-  LOG(ERROR) << "rocBLAS does not currently support the GEMMBatched operation "
-             << "for the \"complex\" dataype";
-  return false;
+  blas_log("DoBlasGemmBatched");
+  port::Status status = DoBlasGemmBatchedInternal(
+      wrap::rocblas_cgemm_strided_batched, stream, transa, transb, m, n, k,
+      alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
+      scratch_allocator);
+  if (!status.ok()) {
+    LOG(ERROR) << status;
+  }
+  return status.ok();
 }
 
 bool ROCMBlas::DoBlasGemmBatched(
     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
     uint64 n, uint64 k, std::complex<double> alpha,
@@ -2046,9 +2149,15 @@ bool ROCMBlas::DoBlasGemmBatched(
     int ldb, std::complex<double> beta,
     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c_array,
     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
-  LOG(ERROR) << "rocBLAS does not currently support the GEMMBatched operation "
-             << "for the \"complex\" dataype";
-  return false;
+  
blas_log("DoBlasGemmBatched"); + port::Status status = DoBlasGemmBatchedInternal( + wrap::rocblas_zgemm_strided_batched, stream, transa, transb, m, n, k, + alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, + scratch_allocator); + if (!status.ok()) { + LOG(ERROR) << status; + } + return status.ok(); } bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side, @@ -2059,7 +2168,7 @@ bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side, std::complex beta, DeviceMemory> *c, int ldc) { LOG(ERROR) << "rocBLAS does not currently support the HEMM operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -2071,7 +2180,7 @@ bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side, std::complex beta, DeviceMemory> *c, int ldc) { LOG(ERROR) << "rocBLAS does not currently support the HEMM operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -2082,7 +2191,7 @@ bool ROCMBlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo, float beta, DeviceMemory> *c, int ldc) { LOG(ERROR) << "rocBLAS does not currently support the HERK operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -2093,7 +2202,7 @@ bool ROCMBlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo, double beta, DeviceMemory> *c, int ldc) { LOG(ERROR) << "rocBLAS does not currently support the HERK operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -2105,7 +2214,7 @@ bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo, float beta, DeviceMemory> *c, int ldc) { LOG(ERROR) << "rocBLAS does not currently support the HER2K operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -2117,7 +2226,7 @@ bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo, double beta, DeviceMemory> *c, int ldc) { LOG(ERROR) << "rocBLAS does not currently support the HER2K operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -2127,7 +2236,7 @@ bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side, const DeviceMemory &b, int ldb, float beta, DeviceMemory *c, int ldc) { LOG(ERROR) << "rocBLAS does not currently support the SYMM operation " - << "for the \"float\" dataype"; + << "for the \"float\" datatype"; return false; } @@ -2137,7 +2246,7 @@ bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side, const DeviceMemory &b, int ldb, double beta, DeviceMemory *c, int ldc) { LOG(ERROR) << "rocBLAS does not currently support the SYMM operation " - << "for the \"double\" dataype"; + << "for the \"double\" datatype"; return false; } @@ -2149,7 +2258,7 @@ bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side, std::complex beta, DeviceMemory> *c, int ldc) { LOG(ERROR) << "rocBLAS does not currently support the SYMM operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -2161,7 +2270,7 @@ bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side, std::complex beta, DeviceMemory> *c, int ldc) { LOG(ERROR) << "rocBLAS does not currently support the SYMM operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -2170,7 +2279,7 @@ bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo, float alpha, const DeviceMemory &a, int lda, float beta, DeviceMemory *c, int ldc) { LOG(ERROR) << "rocBLAS does not currently support the SYRK operation 
" - << "for the \"float\" dataype"; + << "for the \"float\" datatype"; return false; } @@ -2179,7 +2288,7 @@ bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo, double alpha, const DeviceMemory &a, int lda, double beta, DeviceMemory *c, int ldc) { LOG(ERROR) << "rocBLAS does not currently support the SYRK operation " - << "for the \"double\" dataype"; + << "for the \"double\" datatype"; return false; } @@ -2190,7 +2299,7 @@ bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo, std::complex beta, DeviceMemory> *c, int ldc) { LOG(ERROR) << "rocBLAS does not currently support the SYRK operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -2201,7 +2310,7 @@ bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo, std::complex beta, DeviceMemory> *c, int ldc) { LOG(ERROR) << "rocBLAS does not currently support the SYRK operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -2211,7 +2320,7 @@ bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, const DeviceMemory &b, int ldb, float beta, DeviceMemory *c, int ldc) { LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation " - << "for the \"float\" dataype"; + << "for the \"float\" datatype"; return false; } @@ -2221,7 +2330,7 @@ bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, const DeviceMemory &b, int ldb, double beta, DeviceMemory *c, int ldc) { LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation " - << "for the \"double\" dataype"; + << "for the \"double\" datatype"; return false; } @@ -2233,7 +2342,7 @@ bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, std::complex beta, DeviceMemory> *c, int ldc) { LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -2245,7 +2354,7 @@ bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, std::complex beta, DeviceMemory> *c, int ldc) { LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -2255,7 +2364,7 @@ bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side, const DeviceMemory &a, int lda, DeviceMemory *b, int ldb) { LOG(ERROR) << "rocBLAS does not currently support the TRMM operation " - << "for the \"float\" dataype"; + << "for the \"float\" datatype"; return false; } @@ -2265,7 +2374,7 @@ bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side, const DeviceMemory &a, int lda, DeviceMemory *b, int ldb) { LOG(ERROR) << "rocBLAS does not currently support the TRMM operation " - << "for the \"double\" dataype"; + << "for the \"double\" datatype"; return false; } @@ -2276,7 +2385,7 @@ bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side, const DeviceMemory> &a, int lda, DeviceMemory> *b, int ldb) { LOG(ERROR) << "rocBLAS does not currently support the TRMM operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -2287,7 +2396,7 @@ bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side, const DeviceMemory> &a, int lda, DeviceMemory> *b, int ldb) { LOG(ERROR) << "rocBLAS does not currently support the TRMM operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -2296,6 +2405,7 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, blas::Diagonal diag, uint64 m, 
uint64 n, float alpha, const DeviceMemory &a, int lda, DeviceMemory *b, int ldb) { + blas_log("DoBlasTrsm"); return DoBlasInternal( wrap::rocblas_strsm, stream, true /* = pointer_mode_host */, ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), @@ -2308,6 +2418,7 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, blas::Diagonal diag, uint64 m, uint64 n, double alpha, const DeviceMemory &a, int lda, DeviceMemory *b, int ldb) { + blas_log("DoBlasTrsm"); return DoBlasInternal( wrap::rocblas_dtrsm, stream, true /* = pointer_mode_host */, ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), @@ -2322,7 +2433,7 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, const DeviceMemory> &a, int lda, DeviceMemory> *b, int ldb) { LOG(ERROR) << "rocBLAS does not currently support the TRSM operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } @@ -2333,15 +2444,17 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, const DeviceMemory> &a, int lda, DeviceMemory> *b, int ldb) { LOG(ERROR) << "rocBLAS does not currently support the TRSM operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } + bool ROCMBlas::DoBlasGemmStridedBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory &a, int lda, int64 stride_a, const DeviceMemory &b, int ldb, int64 stride_b, float beta, DeviceMemory *c, int ldc, int64 stride_c, int batch_count) { + blas_log("DoBlasGemmStridedBatched"); const Eigen::half alpha_half(alpha); const Eigen::half beta_half(beta); @@ -2363,6 +2476,7 @@ bool ROCMBlas::DoBlasGemmStridedBatched( int64 stride_a, const DeviceMemory &b, int ldb, int64 stride_b, float beta, DeviceMemory *c, int ldc, int64 stride_c, int batch_count) { + blas_log("DoBlasGemmStridedBatched"); return DoBlasInternal(wrap::rocblas_sgemm_strided_batched, stream, false, /* pointer_mode_host */ ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, @@ -2376,6 +2490,7 @@ bool ROCMBlas::DoBlasGemmStridedBatched( int64 stride_a, const DeviceMemory &b, int ldb, int64 stride_b, double beta, DeviceMemory *c, int ldc, int64 stride_c, int batch_count) { + blas_log("DoBlasGemmStridedBatched"); return DoBlasInternal(wrap::rocblas_dgemm_strided_batched, stream, false, /* pointer_mode_host */ ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, @@ -2392,7 +2507,7 @@ bool ROCMBlas::DoBlasGemmStridedBatched( int64 stride_c, int batch_count) { LOG(ERROR) << "rocBLAS does not currently support the " "DoBlasGemmStridedBatched operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } bool ROCMBlas::DoBlasGemmStridedBatched( @@ -2404,7 +2519,7 @@ bool ROCMBlas::DoBlasGemmStridedBatched( int64 stride_c, int batch_count) { LOG(ERROR) << "rocBLAS does not currently support the " "DoBlasGemmStridedBatched operation " - << "for the \"complex\" dataype"; + << "for the \"complex\" datatype"; return false; } diff --git a/tensorflow/stream_executor/rocm/rocm_blas.h b/tensorflow/stream_executor/rocm/rocm_blas.h index 1b73a356b88..6bf30049686 100644 --- a/tensorflow/stream_executor/rocm/rocm_blas.h +++ b/tensorflow/stream_executor/rocm/rocm_blas.h @@ -45,6 +45,16 @@ struct RocBlasTypeConversionHelper { using mapped_type = rocblas_half; }; +template <> +struct RocBlasTypeConversionHelper > { + using mapped_type = rocblas_float_complex; +}; + +template <> +struct 
RocBlasTypeConversionHelper > { + using mapped_type = rocblas_double_complex; +}; + // Opaque and unique identifier for the rocBLAS plugin. extern const PluginId kRocBlasPlugin; @@ -110,7 +120,7 @@ class ROCMBlas : public blas::BlasSupport { /*err_on_failure=*/false, args...); } - // A helper allocation funciton to convert raw pointers memory layout to + // A helper allocation function to convert raw pointers memory layout to // strided flavor template port::Status AllocateStridedBuffer( @@ -121,7 +131,8 @@ class ROCMBlas : public blas::BlasSupport { std::unique_ptr::mapped_type>> *temp_memory, DeviceMemory::mapped_type> - *device_memory); + *device_memory, bool copy_data, + bool& reallocated); // A helper function to implement DoBlasGemmBatched interfaces for generic // types. From 543db6fc6713ed9ba19cf798a92f4bd2f4ad9ba2 Mon Sep 17 00:00:00 2001 From: Eugene Kuznetsov Date: Mon, 23 Dec 2019 23:08:44 -0800 Subject: [PATCH 2/4] Enabling tests affected by support of complex GEMM and GEMV --- tensorflow/python/kernel_tests/BUILD | 1 - .../kernel_tests/batch_matmul_op_test.py | 5 +- tensorflow/python/kernel_tests/eig_op_test.py | 5 +- .../python/kernel_tests/init_ops_test.py | 7 --- .../linalg/linear_operator_adjoint_test.py | 4 +- .../linalg/linear_operator_circulant_test.py | 20 ------- tensorflow/python/kernel_tests/lu_op_test.py | 56 ++++++++----------- .../python/kernel_tests/matmul_op_test.py | 5 +- .../matrix_exponential_op_test.py | 4 -- .../kernel_tests/matrix_inverse_op_test.py | 38 ++++++------- .../kernel_tests/matrix_logarithm_op_test.py | 8 --- .../matrix_square_root_op_test.py | 17 +++--- .../kernel_tests/self_adjoint_eig_op_test.py | 5 +- tensorflow/python/kernel_tests/signal/BUILD | 1 - tensorflow/python/kernel_tests/svd_op_test.py | 7 +-- .../python/kernel_tests/tensordot_op_test.py | 5 +- .../python/ops/special_math_ops_test.py | 8 +-- 17 files changed, 56 insertions(+), 140 deletions(-) diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 6ea17b4fa5a..37d406e50d1 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -3405,7 +3405,6 @@ tf_py_test( data = ["//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files"], shard_count = 20, tags = [ - "no_rocm", # flaky test "no_windows", ], deps = [ diff --git a/tensorflow/python/kernel_tests/batch_matmul_op_test.py b/tensorflow/python/kernel_tests/batch_matmul_op_test.py index 55eca193d64..b68eaa123c5 100644 --- a/tensorflow/python/kernel_tests/batch_matmul_op_test.py +++ b/tensorflow/python/kernel_tests/batch_matmul_op_test.py @@ -262,10 +262,7 @@ class BatchMatMulBenchmark(test.Benchmark): if __name__ == "__main__": - dtypes_to_test = [np.float16, np.float32, np.float64, np.int32] - if not test.is_built_with_rocm(): - # ROCm does not support BLAS operations for complex types - dtypes_to_test += [np.complex64, np.complex128] + dtypes_to_test = [np.float16, np.float32, np.float64, np.int32, np.complex64, np.complex128] for dtype_ in dtypes_to_test: for adjoint_a_ in False, True: for adjoint_b_ in False, True: diff --git a/tensorflow/python/kernel_tests/eig_op_test.py b/tensorflow/python/kernel_tests/eig_op_test.py index ffc61b7bcfe..74607c66dc2 100644 --- a/tensorflow/python/kernel_tests/eig_op_test.py +++ b/tensorflow/python/kernel_tests/eig_op_test.py @@ -183,10 +183,7 @@ def _GetEigTest(dtype_, shape_, compute_v_): if __name__ == "__main__": - dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64] - if not 
test.is_built_with_rocm(): - # ROCm does not support BLAS operations for complex types - dtypes_to_test += [dtypes_lib.complex64, dtypes_lib.complex128] + dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64, dtypes_lib.complex128] for compute_v in True, False: for dtype in dtypes_to_test: for size in 1, 2, 5, 10: diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py index 3822b4b89fc..4b9681afd2c 100644 --- a/tensorflow/python/kernel_tests/init_ops_test.py +++ b/tensorflow/python/kernel_tests/init_ops_test.py @@ -746,13 +746,6 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase): else: shape = [4, 16, 16, 16, 64] convolution = convolutional.conv3d - - if test.is_built_with_rocm(): - # This subtest triggers a known bug in ROCm runtime code - # The bug has been fixed and will be available in ROCm 2.7 - # Re-enable this test once ROCm 2.7 is released - continue - inputs = random_ops.random_normal(shape, dtype=dtype) inputs_2norm = linalg_ops.norm(inputs) outputs = convolution( diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py index c03203f02e5..409ab20985d 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py @@ -141,8 +141,6 @@ class LinearOperatorAdjointTest( full_matrix2, adjoint=True, adjoint_arg=True).to_dense())) def test_matmul_adjoint_complex_operator(self): - if test.is_built_with_rocm(): - self.skipTest("ROCm does not support BLAS operations for complex types") matrix1 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4) matrix2 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4) full_matrix1 = linalg.LinearOperatorFullMatrix(matrix1) @@ -201,7 +199,7 @@ class LinearOperatorAdjointTest( def test_solve_adjoint_complex_operator(self): if test.is_built_with_rocm(): - self.skipTest("ROCm does not support BLAS operations for complex types") + self.skipTest("ROCm does not support BLAS solve operations for complex types") matrix1 = self.evaluate(linear_operator_test_util.random_tril_matrix( [4, 4], dtype=dtypes.complex128, force_well_conditioned=True) + 1j * linear_operator_test_util.random_tril_matrix( diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py index 810c47ba1e8..b3c8c2d20ff 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py @@ -357,11 +357,6 @@ class LinearOperatorCirculantTestNonHermitianSpectrum( self.evaluate(operator.assert_non_singular()) def test_assert_non_singular_does_not_fail_for_non_singular_operator(self): - - if test.is_built_with_rocm(): - # ROCm does not yet support BLAS operations with complex types. 
- self.skipTest("ROCm does not support BLAS operations for complex types") - spectrum = math_ops.cast([-3j, 4 + 0j, 2j + 2], dtypes.complex64) operator = linalg.LinearOperatorCirculant(spectrum) with self.cached_session(): @@ -665,11 +660,6 @@ class LinearOperatorCirculant3DTest(test.TestCase): yield sess def test_real_spectrum_gives_self_adjoint_operator(self): - - if test.is_built_with_rocm(): - # ROCm does not yet support BLAS operations with complext types - self.skipTest("ROCm does not support BLAS operations for complex types") - with self.cached_session(): # This is a real and hermitian spectrum. spectrum = linear_operator_test_util.random_normal( @@ -686,11 +676,6 @@ class LinearOperatorCirculant3DTest(test.TestCase): self.assertAllClose(matrix, matrix_h) def test_defining_operator_using_real_convolution_kernel(self): - - if test.is_built_with_rocm(): - # ROCm does not yet support BLAS operations with complext types - self.skipTest("ROCm does not support BLAS operations for complex types") - with self.cached_session(): convolution_kernel = linear_operator_test_util.random_normal( shape=(2, 2, 3, 5), dtype=dtypes.float32) @@ -709,11 +694,6 @@ class LinearOperatorCirculant3DTest(test.TestCase): np.testing.assert_allclose(0, np.imag(matrix), atol=1e-5) def test_defining_spd_operator_by_taking_real_part(self): - - if test.is_built_with_rocm(): - # ROCm does not yet support BLAS operations with complext types - self.skipTest("ROCm does not support BLAS operations for complex types") - with self.cached_session(): # Necessary for fft_kernel_label_map # S is real and positive. s = linear_operator_test_util.random_uniform( diff --git a/tensorflow/python/kernel_tests/lu_op_test.py b/tensorflow/python/kernel_tests/lu_op_test.py index 875a3768602..1c0280c3ce6 100644 --- a/tensorflow/python/kernel_tests/lu_op_test.py +++ b/tensorflow/python/kernel_tests/lu_op_test.py @@ -130,14 +130,12 @@ class LuOpTest(test.TestCase): for output_idx_type in (dtypes.int32, dtypes.int64): self._verifyLu(data.astype(dtype), output_idx_type=output_idx_type) - if not test.is_built_with_rocm(): - # ROCm does not support BLAS operations for complex types - for dtype in (np.complex64, np.complex128): - for output_idx_type in (dtypes.int32, dtypes.int64): - complex_data = np.tril(1j * data, -1).astype(dtype) - complex_data += np.triu(-1j * data, 1).astype(dtype) - complex_data += data - self._verifyLu(complex_data, output_idx_type=output_idx_type) + for dtype in (np.complex64, np.complex128): + for output_idx_type in (dtypes.int32, dtypes.int64): + complex_data = np.tril(1j * data, -1).astype(dtype) + complex_data += np.triu(-1j * data, 1).astype(dtype) + complex_data += data + self._verifyLu(complex_data, output_idx_type=output_idx_type) def testPivoting(self): # This matrix triggers partial pivoting because the first diagonal entry @@ -152,17 +150,15 @@ class LuOpTest(test.TestCase): # Make sure p_val is not the identity permutation. self.assertNotAllClose(np.arange(3), p_val) - if not test.is_built_with_rocm(): - # ROCm does not support BLAS operations for complex types - for dtype in (np.complex64, np.complex128): - complex_data = np.tril(1j * data, -1).astype(dtype) - complex_data += np.triu(-1j * data, 1).astype(dtype) - complex_data += data - self._verifyLu(complex_data) - _, p = linalg_ops.lu(data) - p_val = self.evaluate([p]) - # Make sure p_val is not the identity permutation. 
- self.assertNotAllClose(np.arange(3), p_val) + for dtype in (np.complex64, np.complex128): + complex_data = np.tril(1j * data, -1).astype(dtype) + complex_data += np.triu(-1j * data, 1).astype(dtype) + complex_data += data + self._verifyLu(complex_data) + _, p = linalg_ops.lu(data) + p_val = self.evaluate([p]) + # Make sure p_val is not the identity permutation. + self.assertNotAllClose(np.arange(3), p_val) def testInvalidMatrix(self): # LU factorization gives an error when the input is singular. @@ -195,13 +191,11 @@ class LuOpTest(test.TestCase): matrices = np.random.rand(batch_size, 5, 5) self._verifyLu(matrices) - if not test.is_built_with_rocm(): - # ROCm does not support BLAS operations for complex types - # Generate random complex valued matrices. - np.random.seed(52) - matrices = np.random.rand(batch_size, 5, - 5) + 1j * np.random.rand(batch_size, 5, 5) - self._verifyLu(matrices) + # Generate random complex valued matrices. + np.random.seed(52) + matrices = np.random.rand(batch_size, 5, + 5) + 1j * np.random.rand(batch_size, 5, 5) + self._verifyLu(matrices) def testLargeMatrix(self): # Generate random matrices. @@ -210,12 +204,10 @@ class LuOpTest(test.TestCase): data = np.random.rand(n, n) self._verifyLu(data) - if not test.is_built_with_rocm(): - # ROCm does not support BLAS operations for complex types - # Generate random complex valued matrices. - np.random.seed(129) - data = np.random.rand(n, n) + 1j * np.random.rand(n, n) - self._verifyLu(data) + # Generate random complex valued matrices. + np.random.seed(129) + data = np.random.rand(n, n) + 1j * np.random.rand(n, n) + self._verifyLu(data) @test_util.run_v1_only("b/120545219") def testEmpty(self): diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py index 9f84946397e..dedfa58c3ed 100644 --- a/tensorflow/python/kernel_tests/matmul_op_test.py +++ b/tensorflow/python/kernel_tests/matmul_op_test.py @@ -226,10 +226,7 @@ class MatMulInfixOperatorTest(test_lib.TestCase): if __name__ == "__main__": sizes = [1, 3, 5] trans_options = [[False, False], [True, False], [False, True]] - dtypes_to_test = [np.int32, np.int64, np.float16, np.float32, np.float64] - if not test_lib.is_built_with_rocm(): - # ROCm does not support BLAS operations for complex types - dtypes_to_test += [np.complex64, np.complex128] + dtypes_to_test = [np.int32, np.int64, np.float16, np.float32, np.float64, np.complex64, np.complex128] # TF2 does not support placeholders under eager so we skip it for use_static_shape in set([True, tf2.enabled()]): for dtype in dtypes_to_test: diff --git a/tensorflow/python/kernel_tests/matrix_exponential_op_test.py b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py index ed47e8980d9..007db5d2076 100644 --- a/tensorflow/python/kernel_tests/matrix_exponential_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py @@ -91,8 +91,6 @@ class ExponentialOpTest(test.TestCase): @test_util.run_deprecated_v1 def testNonsymmetricComplex(self): - if test.is_built_with_rocm(): - self.skipTest("ROCm does not support BLAS operations for complex types") matrix1 = np.array([[1., 2.], [3., 4.]]) matrix2 = np.array([[1., 3.], [3., 5.]]) matrix1 = matrix1.astype(np.complex64) @@ -114,8 +112,6 @@ class ExponentialOpTest(test.TestCase): self._verifyExponentialReal(self._makeBatch(matrix1, matrix2)) def testSymmetricPositiveDefiniteComplex(self): - if test.is_built_with_rocm(): - self.skipTest("ROCm does not support BLAS operations for complex types") matrix1 = 
np.array([[2., 1.], [1., 2.]]) matrix2 = np.array([[3., -1.], [-1., 3.]]) matrix1 = matrix1.astype(np.complex64) diff --git a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py index 56a242c0234..244e95eefa2 100644 --- a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py @@ -74,17 +74,14 @@ class InverseOpTest(test.TestCase): self._verifyInverseReal(matrix2) # A multidimensional batch of 2x2 matrices self._verifyInverseReal(self._makeBatch(matrix1, matrix2)) - if not test.is_built_with_rocm(): - # ROCm does not support BLAS operations for complex types - # Complex - matrix1 = matrix1.astype(np.complex64) - matrix1 += 1j * matrix1 - matrix2 = matrix2.astype(np.complex64) - matrix2 += 1j * matrix2 - self._verifyInverseComplex(matrix1) - self._verifyInverseComplex(matrix2) - # Complex batch - self._verifyInverseComplex(self._makeBatch(matrix1, matrix2)) + matrix1 = matrix1.astype(np.complex64) + matrix1 += 1j * matrix1 + matrix2 = matrix2.astype(np.complex64) + matrix2 += 1j * matrix2 + self._verifyInverseComplex(matrix1) + self._verifyInverseComplex(matrix2) + # Complex batch + self._verifyInverseComplex(self._makeBatch(matrix1, matrix2)) def testSymmetricPositiveDefinite(self): # 2x2 matrices @@ -94,17 +91,14 @@ class InverseOpTest(test.TestCase): self._verifyInverseReal(matrix2) # A multidimensional batch of 2x2 matrices self._verifyInverseReal(self._makeBatch(matrix1, matrix2)) - if not test.is_built_with_rocm(): - # ROCm does not support BLAS operations for complex types - # Complex - matrix1 = matrix1.astype(np.complex64) - matrix1 += 1j * matrix1 - matrix2 = matrix2.astype(np.complex64) - matrix2 += 1j * matrix2 - self._verifyInverseComplex(matrix1) - self._verifyInverseComplex(matrix2) - # Complex batch - self._verifyInverseComplex(self._makeBatch(matrix1, matrix2)) + matrix1 = matrix1.astype(np.complex64) + matrix1 += 1j * matrix1 + matrix2 = matrix2.astype(np.complex64) + matrix2 += 1j * matrix2 + self._verifyInverseComplex(matrix1) + self._verifyInverseComplex(matrix2) + # Complex batch + self._verifyInverseComplex(self._makeBatch(matrix1, matrix2)) @test_util.deprecated_graph_mode_only def testNonSquareMatrix(self): diff --git a/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py b/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py index 0f5da8b27a4..6d9ba6b66c0 100644 --- a/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py @@ -59,8 +59,6 @@ class LogarithmOpTest(test.TestCase): @test_util.run_v1_only("b/120545219") def testNonsymmetric(self): - if test.is_built_with_rocm(): - self.skipTest("ROCm does not support BLAS operations for complex types") # 2x2 matrices matrix1 = np.array([[1., 2.], [3., 4.]]) matrix2 = np.array([[1., 3.], [3., 5.]]) @@ -75,8 +73,6 @@ class LogarithmOpTest(test.TestCase): @test_util.run_v1_only("b/120545219") def testSymmetricPositiveDefinite(self): - if test.is_built_with_rocm(): - self.skipTest("ROCm does not support BLAS operations for complex types") # 2x2 matrices matrix1 = np.array([[2., 1.], [1., 2.]]) matrix2 = np.array([[3., -1.], [-1., 3.]]) @@ -111,8 +107,6 @@ class LogarithmOpTest(test.TestCase): @test_util.run_v1_only("b/120545219") def testRandomSmallAndLargeComplex64(self): - if test.is_built_with_rocm(): - self.skipTest("ROCm does not support BLAS operations for complex types") np.random.seed(42) for batch_dims in [(), 
(1,), (3,), (2, 2)]: for size in 8, 31, 32: @@ -124,8 +118,6 @@ class LogarithmOpTest(test.TestCase): @test_util.run_v1_only("b/120545219") def testRandomSmallAndLargeComplex128(self): - if test.is_built_with_rocm(): - self.skipTest("ROCm does not support BLAS operations for complex types") np.random.seed(42) for batch_dims in [(), (1,), (3,), (2, 2)]: for size in 8, 31, 32: diff --git a/tensorflow/python/kernel_tests/matrix_square_root_op_test.py b/tensorflow/python/kernel_tests/matrix_square_root_op_test.py index 2a761140b0a..c36d83e2530 100644 --- a/tensorflow/python/kernel_tests/matrix_square_root_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_square_root_op_test.py @@ -59,16 +59,13 @@ class SquareRootOpTest(test.TestCase): self._verifySquareRootReal(matrix1) self._verifySquareRootReal(matrix2) self._verifySquareRootReal(self._makeBatch(matrix1, matrix2)) - if not test.is_built_with_rocm(): - # ROCm does not support BLAS operations for complex types - # Complex - matrix1 = matrix1.astype(np.complex64) - matrix2 = matrix2.astype(np.complex64) - matrix1 += 1j * matrix1 - matrix2 += 1j * matrix2 - self._verifySquareRootComplex(matrix1) - self._verifySquareRootComplex(matrix2) - self._verifySquareRootComplex(self._makeBatch(matrix1, matrix2)) + matrix1 = matrix1.astype(np.complex64) + matrix2 = matrix2.astype(np.complex64) + matrix1 += 1j * matrix1 + matrix2 += 1j * matrix2 + self._verifySquareRootComplex(matrix1) + self._verifySquareRootComplex(matrix2) + self._verifySquareRootComplex(self._makeBatch(matrix1, matrix2)) def testSymmetricPositiveDefinite(self): matrix1 = np.array([[2., 1.], [1., 2.]]) diff --git a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py index 0ada446e84b..ef64f7cf61b 100644 --- a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py +++ b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py @@ -240,10 +240,7 @@ def _GetSelfAdjointEigGradTest(dtype_, shape_, compute_v_): if __name__ == "__main__": - dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64] - if not test.is_built_with_rocm(): - # ROCm does not support BLAS operations for complex types - dtypes_to_test += [dtypes_lib.complex64, dtypes_lib.complex128] + dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64, dtypes_lib.complex128] for compute_v in True, False: for dtype in dtypes_to_test: for size in 1, 2, 5, 10: diff --git a/tensorflow/python/kernel_tests/signal/BUILD b/tensorflow/python/kernel_tests/signal/BUILD index 230b35ccf02..08e65d1fb7c 100644 --- a/tensorflow/python/kernel_tests/signal/BUILD +++ b/tensorflow/python/kernel_tests/signal/BUILD @@ -123,7 +123,6 @@ cuda_py_tests( srcs = ["spectral_ops_test.py"], python_version = "PY3", tags = [ - "no_rocm", "nomac", ], deps = [ diff --git a/tensorflow/python/kernel_tests/svd_op_test.py b/tensorflow/python/kernel_tests/svd_op_test.py index bbcab12a163..9da9f3d6717 100644 --- a/tensorflow/python/kernel_tests/svd_op_test.py +++ b/tensorflow/python/kernel_tests/svd_op_test.py @@ -370,10 +370,7 @@ class SVDBenchmark(test.Benchmark): if __name__ == "__main__": - dtypes_to_test = [np.float32, np.float64] - if not test.is_built_with_rocm(): - # ROCm does not support BLAS operations for complex types - dtypes_to_test += [np.complex64, np.complex128] + dtypes_to_test = [np.float32, np.float64, np.complex64, np.complex128] for compute_uv in False, True: for full_matrices in False, True: for dtype in dtypes_to_test: @@ -392,7 +389,7 @@ if __name__ 
== "__main__": for compute_uv in False, True: for full_matrices in False, True: dtypes = ([np.float32, np.float64] + [np.complex64, np.complex128] * - (not compute_uv) * (not test.is_built_with_rocm())) + (not compute_uv)) for dtype in dtypes: mat_shapes = [(10, 11), (11, 10), (11, 11), (2, 2, 2, 3)] if not full_matrices or not compute_uv: diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py index 635a76323f6..3663a91281b 100644 --- a/tensorflow/python/kernel_tests/tensordot_op_test.py +++ b/tensorflow/python/kernel_tests/tensordot_op_test.py @@ -221,10 +221,7 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_): if __name__ == "__main__": - dtypes_to_test = [np.float16, np.float32, np.float64] - if not test_lib.is_built_with_rocm(): - # ROCm does not support BLAS operations for complex types - dtypes_to_test += [np.complex64, np.complex128] + dtypes_to_test = [np.float16, np.float32, np.float64, np.complex64, np.complex128] for dtype in dtypes_to_test: for rank_a in 1, 2, 4, 5: for rank_b in 1, 2, 4, 5: diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py index 77136adc5b4..ae77db4653e 100644 --- a/tensorflow/python/ops/special_math_ops_test.py +++ b/tensorflow/python/ops/special_math_ops_test.py @@ -338,13 +338,7 @@ class EinsumTest(test.TestCase): self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,)) def test_dtypes(self): - dtypes = [] - if test.is_built_with_rocm(): - # This test triggers the BLAS op calls on the GPU - # ROCm does not support BLAS operations for complex types - dtypes = [np.float64, np.float32] - else: - dtypes = [np.float64, np.float32, np.complex64, np.complex128] + dtypes = [np.float64, np.float32, np.complex64, np.complex128] for dtype in dtypes: self._check('ij,jk->ik', (2, 2), (2, 2), dtype=dtype) self._check('ji,jk->ik', (2, 2), (2, 2), dtype=dtype) From 0f0aa375122dea501a237a5f8462bfde31d03a7d Mon Sep 17 00:00:00 2001 From: Eugene Kuznetsov Date: Wed, 15 Jan 2020 13:15:04 -0800 Subject: [PATCH 3/4] Linter errors --- tensorflow/python/kernel_tests/batch_matmul_op_test.py | 3 ++- tensorflow/python/kernel_tests/eig_op_test.py | 3 ++- .../python/kernel_tests/linalg/linear_operator_adjoint_test.py | 3 ++- tensorflow/python/kernel_tests/matmul_op_test.py | 3 ++- tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py | 3 ++- tensorflow/python/kernel_tests/tensordot_op_test.py | 3 ++- 6 files changed, 12 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/kernel_tests/batch_matmul_op_test.py b/tensorflow/python/kernel_tests/batch_matmul_op_test.py index b68eaa123c5..dab4116ab9d 100644 --- a/tensorflow/python/kernel_tests/batch_matmul_op_test.py +++ b/tensorflow/python/kernel_tests/batch_matmul_op_test.py @@ -262,7 +262,8 @@ class BatchMatMulBenchmark(test.Benchmark): if __name__ == "__main__": - dtypes_to_test = [np.float16, np.float32, np.float64, np.int32, np.complex64, np.complex128] + dtypes_to_test = [np.float16, np.float32, np.float64, np.int32, + np.complex64, np.complex128] for dtype_ in dtypes_to_test: for adjoint_a_ in False, True: for adjoint_b_ in False, True: diff --git a/tensorflow/python/kernel_tests/eig_op_test.py b/tensorflow/python/kernel_tests/eig_op_test.py index 74607c66dc2..4cfbcd21b49 100644 --- a/tensorflow/python/kernel_tests/eig_op_test.py +++ b/tensorflow/python/kernel_tests/eig_op_test.py @@ -183,7 +183,8 @@ def _GetEigTest(dtype_, shape_, compute_v_): if __name__ == "__main__": - 
dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64, dtypes_lib.complex128] + dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64, + dtypes_lib.complex64, dtypes_lib.complex128] for compute_v in True, False: for dtype in dtypes_to_test: for size in 1, 2, 5, 10: diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py index 409ab20985d..ad419ced5d1 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py @@ -199,7 +199,8 @@ class LinearOperatorAdjointTest( def test_solve_adjoint_complex_operator(self): if test.is_built_with_rocm(): - self.skipTest("ROCm does not support BLAS solve operations for complex types") + self.skipTest("ROCm does not support BLAS solve operations" + " for complex types") matrix1 = self.evaluate(linear_operator_test_util.random_tril_matrix( [4, 4], dtype=dtypes.complex128, force_well_conditioned=True) + 1j * linear_operator_test_util.random_tril_matrix( diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py index dedfa58c3ed..cf562e094ed 100644 --- a/tensorflow/python/kernel_tests/matmul_op_test.py +++ b/tensorflow/python/kernel_tests/matmul_op_test.py @@ -226,7 +226,8 @@ class MatMulInfixOperatorTest(test_lib.TestCase): if __name__ == "__main__": sizes = [1, 3, 5] trans_options = [[False, False], [True, False], [False, True]] - dtypes_to_test = [np.int32, np.int64, np.float16, np.float32, np.float64, np.complex64, np.complex128] + dtypes_to_test = [np.int32, np.int64, np.float16, np.float32, np.float64, + np.complex64, np.complex128] # TF2 does not support placeholders under eager so we skip it for use_static_shape in set([True, tf2.enabled()]): for dtype in dtypes_to_test: diff --git a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py index ef64f7cf61b..73609a3c1cf 100644 --- a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py +++ b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py @@ -240,7 +240,8 @@ def _GetSelfAdjointEigGradTest(dtype_, shape_, compute_v_): if __name__ == "__main__": - dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64, dtypes_lib.complex128] + dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64, + dtypes_lib.complex64, dtypes_lib.complex128] for compute_v in True, False: for dtype in dtypes_to_test: for size in 1, 2, 5, 10: diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py index 3663a91281b..b63e3df6919 100644 --- a/tensorflow/python/kernel_tests/tensordot_op_test.py +++ b/tensorflow/python/kernel_tests/tensordot_op_test.py @@ -221,7 +221,8 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_): if __name__ == "__main__": - dtypes_to_test = [np.float16, np.float32, np.float64, np.complex64, np.complex128] + dtypes_to_test = [np.float16, np.float32, np.float64, + np.complex64, np.complex128] for dtype in dtypes_to_test: for rank_a in 1, 2, 4, 5: for rank_b in 1, 2, 4, 5: From 461598528c06a89dc6d99b43bb455320fb1e3c26 Mon Sep 17 00:00:00 2001 From: Eugene Kuznetsov Date: Thu, 26 Mar 2020 15:42:51 -0700 Subject: [PATCH 4/4] Fixing linter errors --- tensorflow/python/kernel_tests/batch_matmul_op_test.py | 2 +- 
tensorflow/python/kernel_tests/eig_op_test.py | 2 +- .../python/kernel_tests/linalg/linear_operator_adjoint_test.py | 2 +- tensorflow/python/kernel_tests/matmul_op_test.py | 2 +- tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py | 2 +- tensorflow/python/kernel_tests/tensordot_op_test.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/kernel_tests/batch_matmul_op_test.py b/tensorflow/python/kernel_tests/batch_matmul_op_test.py index dab4116ab9d..6be9c28bad3 100644 --- a/tensorflow/python/kernel_tests/batch_matmul_op_test.py +++ b/tensorflow/python/kernel_tests/batch_matmul_op_test.py @@ -263,7 +263,7 @@ class BatchMatMulBenchmark(test.Benchmark): if __name__ == "__main__": dtypes_to_test = [np.float16, np.float32, np.float64, np.int32, - np.complex64, np.complex128] + np.complex64, np.complex128] for dtype_ in dtypes_to_test: for adjoint_a_ in False, True: for adjoint_b_ in False, True: diff --git a/tensorflow/python/kernel_tests/eig_op_test.py b/tensorflow/python/kernel_tests/eig_op_test.py index 4cfbcd21b49..d75326f3119 100644 --- a/tensorflow/python/kernel_tests/eig_op_test.py +++ b/tensorflow/python/kernel_tests/eig_op_test.py @@ -184,7 +184,7 @@ def _GetEigTest(dtype_, shape_, compute_v_): if __name__ == "__main__": dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64, - dtypes_lib.complex64, dtypes_lib.complex128] + dtypes_lib.complex64, dtypes_lib.complex128] for compute_v in True, False: for dtype in dtypes_to_test: for size in 1, 2, 5, 10: diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py index ad419ced5d1..5619f1cd38a 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py @@ -200,7 +200,7 @@ class LinearOperatorAdjointTest( def test_solve_adjoint_complex_operator(self): if test.is_built_with_rocm(): self.skipTest("ROCm does not support BLAS solve operations" - " for complex types") + " for complex types") matrix1 = self.evaluate(linear_operator_test_util.random_tril_matrix( [4, 4], dtype=dtypes.complex128, force_well_conditioned=True) + 1j * linear_operator_test_util.random_tril_matrix( diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py index cf562e094ed..23b2cc98728 100644 --- a/tensorflow/python/kernel_tests/matmul_op_test.py +++ b/tensorflow/python/kernel_tests/matmul_op_test.py @@ -227,7 +227,7 @@ if __name__ == "__main__": sizes = [1, 3, 5] trans_options = [[False, False], [True, False], [False, True]] dtypes_to_test = [np.int32, np.int64, np.float16, np.float32, np.float64, - np.complex64, np.complex128] + np.complex64, np.complex128] # TF2 does not support placeholders under eager so we skip it for use_static_shape in set([True, tf2.enabled()]): for dtype in dtypes_to_test: diff --git a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py index 73609a3c1cf..cc98f3cd785 100644 --- a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py +++ b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py @@ -241,7 +241,7 @@ def _GetSelfAdjointEigGradTest(dtype_, shape_, compute_v_): if __name__ == "__main__": dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64, - dtypes_lib.complex64, dtypes_lib.complex128] + dtypes_lib.complex64, dtypes_lib.complex128] for compute_v in True, False: for 
dtype in dtypes_to_test: for size in 1, 2, 5, 10: diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py index b63e3df6919..b4b623b9140 100644 --- a/tensorflow/python/kernel_tests/tensordot_op_test.py +++ b/tensorflow/python/kernel_tests/tensordot_op_test.py @@ -222,7 +222,7 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_): if __name__ == "__main__": dtypes_to_test = [np.float16, np.float32, np.float64, - np.complex64, np.complex128] + np.complex64, np.complex128] for dtype in dtypes_to_test: for rank_a in 1, 2, 4, 5: for rank_b in 1, 2, 4, 5:
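
Editor's note (not part of the patch series): the mechanism that makes patch 1/4 work is the pair of RocBlasTypeConversionHelper specializations plus the complex_cast() helpers, which reinterpret std::complex pointers as the matching rocblas_*_complex pointers so the existing DoBlasInternal plumbing can forward GEMM/GEMV arguments unchanged. The sketch below is a minimal, self-contained illustration of that pattern, not the patch itself: the rocblas_*_complex structs here are stand-ins for the real definitions in rocblas.h, and it casts a host value rather than GPU device memory.

// complex_cast_sketch.cc -- standalone illustration of the type-mapping
// pattern from patch 1/4. The rocblas_*_complex structs are ASSUMED
// stand-ins; the real types come from <rocblas.h>.
#include <complex>

struct rocblas_float_complex { float x, y; };    // stand-in for rocBLAS type
struct rocblas_double_complex { double x, y; };  // stand-in for rocBLAS type

template <typename T>
struct RocBlasTypeConversionHelper {
  using mapped_type = T;  // real types pass through unchanged
};

template <>
struct RocBlasTypeConversionHelper<std::complex<float>> {
  using mapped_type = rocblas_float_complex;
};

template <>
struct RocBlasTypeConversionHelper<std::complex<double>> {
  using mapped_type = rocblas_double_complex;
};

// Host-side analogue of the complex_cast() helpers: reinterpret a
// std::complex value as the matching rocBLAS type. Sound only if the two
// representations are layout-compatible, which the asserts spell out.
template <typename T>
const typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast(
    const T &a) {
  using M = typename RocBlasTypeConversionHelper<T>::mapped_type;
  static_assert(sizeof(M) == sizeof(T), "size must match");
  static_assert(alignof(M) <= alignof(T), "alignment must be compatible");
  return reinterpret_cast<const M *>(&a);
}

int main() {
  std::complex<double> alpha(1.0, -2.0);
  const rocblas_double_complex *p = complex_cast(alpha);
  return (p->x == 1.0 && p->y == -2.0) ? 0 : 1;  // expect exit code 0
}

The cast is legitimate because std::complex<T> is required by the C++ standard to be layout-compatible with an array of two T values; the same assumption is what lets the wrappers above pass complex device pointers straight into rocblas_cgemm_strided_batched and rocblas_zgemm_strided_batched without copying.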