From d75510457b1dcda7e84c86acfa38bdec4f3177e5 Mon Sep 17 00:00:00 2001
From: Deven Desai
Date: Tue, 2 Jul 2019 03:49:46 +0000
Subject: [PATCH] cosmetic / formatting changes

---
 tensorflow/stream_executor/rocm/rocm_blas.cc  | 158 +++++++++---------
 tensorflow/stream_executor/rocm/rocm_blas.h   |  18 +-
 .../rocm/rocm_driver_wrapper.h                |   6 +-
 tensorflow/stream_executor/rocm/rocm_fft.cc   |  72 ++++----
 tensorflow/stream_executor/rocm/rocm_rng.cc   |   3 +-
 5 files changed, 132 insertions(+), 125 deletions(-)

diff --git a/tensorflow/stream_executor/rocm/rocm_blas.cc b/tensorflow/stream_executor/rocm/rocm_blas.cc
index 52c654617f8..fd444a4cd6b 100644
--- a/tensorflow/stream_executor/rocm/rocm_blas.cc
+++ b/tensorflow/stream_executor/rocm/rocm_blas.cc
@@ -53,14 +53,14 @@ namespace wrap {
 #ifdef PLATFORM_GOOGLE
 #define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \
   struct WrapperShim__##__name { \
-    static const char *kName; \
+    static const char* kName; \
     template <typename... Args> \
-    rocblas_status operator()(GpuExecutor *parent, Args... args) { \
+    rocblas_status operator()(GpuExecutor* parent, Args... args) { \
       gpu::ScopedActivateExecutorContext sac{parent}; \
       return ::__name(args...); \
     } \
   } __name; \
-  const char *WrapperShim__##__name::kName = #__name;
+  const char* WrapperShim__##__name::kName = #__name;
 
 #define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
   STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
@@ -69,14 +69,14 @@ namespace wrap {
 
 #define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \
   struct DynLoadShim__##__name { \
-    static const char *kName; \
+    static const char* kName; \
     using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
-    static void *GetDsoHandle() { \
+    static void* GetDsoHandle() { \
      auto s = internal::CachedDsoLoader::GetRocblasDsoHandle(); \
      return s.ValueOrDie(); \
    } \
    static FuncPtrT LoadOrDie() { \
-      void *f; \
+      void* f; \
      auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
                                                          kName, &f); \
      CHECK(s.ok()) << "could not find " << kName \
@@ -88,12 +88,12 @@ namespace wrap {
      return f; \
    } \
    template <typename... Args> \
-    rocblas_status operator()(GpuExecutor *parent, Args... args) { \
+    rocblas_status operator()(GpuExecutor* parent, Args... args) { \
      gpu::ScopedActivateExecutorContext sac{parent}; \
      return DynLoad()(args...); \
    } \
  } __name; \
-  const char *DynLoadShim__##__name::kName = #__name;
+  const char* DynLoadShim__##__name::kName = #__name;
 
 #define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
   STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
@@ -322,7 +322,7 @@ bool ROCMBlas::Init() {
   return true;
 }
 
-ROCMBlas::ROCMBlas(gpu::GpuExecutor *parent)
+ROCMBlas::ROCMBlas(gpu::GpuExecutor* parent)
     : parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {}
 
 ROCMBlas::~ROCMBlas() {
@@ -1476,11 +1476,11 @@ bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
   return false;
 }
 
-bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
+bool ROCMBlas::DoBlasGemm(Stream* stream, blas::Transpose transa,
                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
-                          float alpha, const DeviceMemory<float> &a,
-                          int lda, const DeviceMemory<float> &b, int ldb,
-                          float beta, DeviceMemory<float> *c, int ldc) {
+                          float alpha, const DeviceMemory<float>& a,
+                          int lda, const DeviceMemory<float>& b, int ldb,
+                          float beta, DeviceMemory<float>* c, int ldc) {
   VLOG(1) << absl::StreamFormat(
       "doing rocBLAS SGEMM: at=%d bt=%d m=%u n=%u "
       "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
@@ -1514,11 +1514,11 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
   return DoBlasInternal(
       wrap::rocblas_hgemm, stream, true /* = pointer_mode_host */,
       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
-      reinterpret_cast<const rocblas_half *>(&alpha_half),
-      reinterpret_cast<const rocblas_half *>(GpuMemory(a)), lda,
-      reinterpret_cast<const rocblas_half *>(GpuMemory(b)), ldb,
-      reinterpret_cast<const rocblas_half *>(&beta_half),
-      reinterpret_cast<rocblas_half *>(GpuMemoryMutable(c)), ldc);
+      reinterpret_cast<const rocblas_half*>(&alpha_half),
+      reinterpret_cast<const rocblas_half*>(GpuMemory(a)), lda,
+      reinterpret_cast<const rocblas_half*>(GpuMemory(b)), ldb,
+      reinterpret_cast<const rocblas_half*>(&beta_half),
+      reinterpret_cast<rocblas_half*>(GpuMemoryMutable(c)), ldc);
 }
 
 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
@@ -1731,12 +1731,12 @@ bool ROCMBlas::GetBlasGemmAlgorithms(
 }
 
 bool ROCMBlas::DoBlasGemmWithAlgorithm(
-    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
-    uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,
-    const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b, int ldb,
-    const HostOrDeviceScalar<int> &beta, DeviceMemory<int32> *c, int ldc,
+    Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, const HostOrDeviceScalar<int>& alpha,
+    const DeviceMemory<int8>& a, int lda, const DeviceMemory<int8>& b, int ldb,
+    const HostOrDeviceScalar<int>& beta, DeviceMemory<int32>* c, int ldc,
     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
-    blas::ProfileResult *output_profile_result) {
+    blas::ProfileResult* output_profile_result) {
   LOG(ERROR)
       << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
      << "for the \"int8\" dataype";
   return false;
 }
@@ -1744,13 +1744,13 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
-    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
-    uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
-    const DeviceMemory<Eigen::half> &a, int lda,
-    const DeviceMemory<Eigen::half> &b, int ldb,
-    const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
+    Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half>& alpha,
+    const DeviceMemory<Eigen::half>& a, int lda,
+    const DeviceMemory<Eigen::half>& b, int ldb,
+    const HostOrDeviceScalar<Eigen::half>& beta, DeviceMemory<Eigen::half>* c,
     int ldc, blas::ComputationType computation_type,
-    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+    blas::AlgorithmType algorithm, blas::ProfileResult* output_profile_result) {
   LOG(ERROR)
       << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
      << "for the \"half\" dataype";
   return false;
 }
@@ -1758,12 +1758,12 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
-    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
-    uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,
-    const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
-    int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
+    Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, const HostOrDeviceScalar<float>& alpha,
+    const DeviceMemory<float>& a, int lda, const DeviceMemory<float>& b,
+    int ldb, const HostOrDeviceScalar<float>& beta, DeviceMemory<float>* c,
     int ldc, blas::ComputationType computation_type,
-    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+    blas::AlgorithmType algorithm, blas::ProfileResult* output_profile_result) {
   LOG(ERROR)
       << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
      << "for the \"float\" dataype";
   return false;
 }
@@ -1771,12 +1771,12 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
-    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
-    uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,
-    const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
-    int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
+    Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, const HostOrDeviceScalar<double>& alpha,
+    const DeviceMemory<double>& a, int lda, const DeviceMemory<double>& b,
+    int ldb, const HostOrDeviceScalar<double>& beta, DeviceMemory<double>* c,
     int ldc, blas::ComputationType computation_type,
-    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+    blas::AlgorithmType algorithm, blas::ProfileResult* output_profile_result) {
   LOG(ERROR)
       << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
      << "for the \"double\" dataype";
   return false;
 }
@@ -1815,14 +1815,14 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
 
 template <typename T>
 port::Status ROCMBlas::AllocateStridedBuffer(
-    const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
-        &raw_ptrs,
-    int batch_count, uint64_t batch_stride, ScratchAllocator *scratch_allocator,
-    Stream *stream,
+    const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type*>&
+        raw_ptrs,
+    int batch_count, uint64_t batch_stride, ScratchAllocator* scratch_allocator,
+    Stream* stream,
     std::unique_ptr<TemporaryDeviceMemory<
-        typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
-    DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
-        *device_memory) {
+        typename RocBlasTypeConversionHelper<T>::mapped_type>>* temp_memory,
+    DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>*
+        device_memory) {
   assert(device_memory != nullptr);
   using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type;
@@ -1969,12 +1969,12 @@ port::Status ROCMBlas::DoBlasGemmBatchedInternal(
 }
 
 bool ROCMBlas::DoBlasGemmBatched(
-    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
     uint64 n, uint64 k, float alpha,
-    const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
-    const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
-    const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
-    int batch_count, ScratchAllocator *scratch_allocator) {
+    const port::ArraySlice<DeviceMemory<Eigen::half>*>& a, int lda,
+    const port::ArraySlice<DeviceMemory<Eigen::half>*>& b, int ldb, float beta,
+    const port::ArraySlice<DeviceMemory<Eigen::half>*>& c, int ldc,
+    int batch_count, ScratchAllocator* scratch_allocator) {
   const Eigen::half alpha_half(alpha);
   const Eigen::half beta_half(beta);
 
@@ -2299,7 +2299,7 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
   return DoBlasInternal(
       wrap::rocblas_strsm, stream, true /* = pointer_mode_host */,
       ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
-      ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<float *>(GpuMemory(a)),
+      ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<float*>(GpuMemory(a)),
       lda, GpuMemoryMutable(b), ldb);
 }
 
@@ -2311,7 +2311,7 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
   return DoBlasInternal(
       wrap::rocblas_dtrsm, stream, true /* = pointer_mode_host */,
       ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
-      ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<double *>(GpuMemory(a)),
+      ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<double*>(GpuMemory(a)),
       lda, GpuMemoryMutable(b), ldb);
 }
 
@@ -2349,19 +2349,19 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
       wrap::rocblas_hgemm_strided_batched, stream,
       false, /* pointer_mode_host */
       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
-      reinterpret_cast<const rocblas_half *>(&alpha_half),
-      reinterpret_cast<const rocblas_half *>(GpuMemory(a)), lda, stride_a,
-      reinterpret_cast<const rocblas_half *>(GpuMemory(b)), ldb, stride_b,
-      reinterpret_cast<const rocblas_half *>(&beta_half),
-      reinterpret_cast<rocblas_half *>(GpuMemoryMutable(c)), ldc, stride_c,
+      reinterpret_cast<const rocblas_half*>(&alpha_half),
+      reinterpret_cast<const rocblas_half*>(GpuMemory(a)), lda, stride_a,
+      reinterpret_cast<const rocblas_half*>(GpuMemory(b)), ldb, stride_b,
+      reinterpret_cast<const rocblas_half*>(&beta_half),
+      reinterpret_cast<rocblas_half*>(GpuMemoryMutable(c)), ldc, stride_c,
       batch_count);
 }
 
 bool ROCMBlas::DoBlasGemmStridedBatched(
-    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
-    uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
-    int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
-    float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+    Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, float alpha, const DeviceMemory<float>& a, int lda,
+    int64 stride_a, const DeviceMemory<float>& b, int ldb, int64 stride_b,
+    float beta, DeviceMemory<float>* c, int ldc, int64 stride_c,
     int batch_count) {
   return DoBlasInternal(wrap::rocblas_sgemm_strided_batched, stream,
                         false, /* pointer_mode_host */
@@ -2371,10 +2371,10 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
                         stride_c, batch_count);
 }
 bool ROCMBlas::DoBlasGemmStridedBatched(
-    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
-    uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
-    int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
-    double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+    Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, double alpha, const DeviceMemory<double>& a, int lda,
+    int64 stride_a, const DeviceMemory<double>& b, int ldb, int64 stride_b,
+    double beta, DeviceMemory<double>* c, int ldc, int64 stride_c,
     int batch_count) {
   return DoBlasInternal(wrap::rocblas_dgemm_strided_batched, stream,
                         false, /* pointer_mode_host */
@@ -2384,11 +2384,11 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
                         stride_c, batch_count);
 }
 bool ROCMBlas::DoBlasGemmStridedBatched(
-    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
     uint64 n, uint64 k, std::complex<float> alpha,
-    const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
-    const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
-    std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+    const DeviceMemory<std::complex<float>>& a, int lda, int64 stride_a,
+    const DeviceMemory<std::complex<float>>& b, int ldb, int64 stride_b,
+    std::complex<float> beta, DeviceMemory<std::complex<float>>* c, int ldc,
     int64 stride_c, int batch_count) {
   LOG(ERROR) << "rocBLAS does not currently support the "
                 "DoBlasGemmStridedBatched operation "
@@ -2396,11 +2396,11 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
   return false;
 }
 bool ROCMBlas::DoBlasGemmStridedBatched(
-    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
     uint64 n, uint64 k, std::complex<double> alpha,
-    const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
-    const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
-    std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+    const DeviceMemory<std::complex<double>>& a, int lda, int64 stride_a,
+    const DeviceMemory<std::complex<double>>& b, int ldb, int64 stride_b,
+    std::complex<double> beta, DeviceMemory<std::complex<double>>* c, int ldc,
     int64 stride_c, int batch_count) {
   LOG(ERROR) << "rocBLAS does not currently support the "
                 "DoBlasGemmStridedBatched operation "
@@ -2418,10 +2418,10 @@ void initialize_rocblas() {
   PluginRegistry::Instance()
       ->RegisterFactory(
           rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS",
-          [](internal::StreamExecutorInterface *parent)
-              -> blas::BlasSupport * {
-            gpu::GpuExecutor *rocm_executor =
-                dynamic_cast<gpu::GpuExecutor *>(parent);
+          [](internal::StreamExecutorInterface* parent)
+              -> blas::BlasSupport* {
+            gpu::GpuExecutor* rocm_executor =
+                dynamic_cast<gpu::GpuExecutor*>(parent);
             if (rocm_executor == nullptr) {
               LOG(ERROR)
                   << "Attempting to initialize an instance of the "
@@ -2430,7 +2430,7 @@ void initialize_rocblas() {
               return nullptr;
             }
 
-            gpu::ROCMBlas *blas = new gpu::ROCMBlas(rocm_executor);
+            gpu::ROCMBlas* blas = new gpu::ROCMBlas(rocm_executor);
             if (!blas->Init()) {
               // Note: Init() will log a more specific error.
               delete blas;
diff --git a/tensorflow/stream_executor/rocm/rocm_blas.h b/tensorflow/stream_executor/rocm/rocm_blas.h
index 1b73a356b88..87e7d6717f3 100644
--- a/tensorflow/stream_executor/rocm/rocm_blas.h
+++ b/tensorflow/stream_executor/rocm/rocm_blas.h
@@ -62,7 +62,7 @@ class GpuExecutor;
 // Thread-safe post-initialization.
 class ROCMBlas : public blas::BlasSupport {
  public:
-  explicit ROCMBlas(GpuExecutor *parent);
+  explicit ROCMBlas(GpuExecutor* parent);
 
   // Allocates a rocBLAS handle.
   bool Init();
@@ -98,7 +98,7 @@ class ROCMBlas : public blas::BlasSupport {
   // Convenience functions that call DoBlasInternalImpl with different values
   // for err_on_failure.
   template <typename FuncT, typename... Args>
-  bool DoBlasInternal(FuncT rocblas_func, Stream *stream,
+  bool DoBlasInternal(FuncT rocblas_func, Stream* stream,
                       bool pointer_mode_host, Args... args) {
     return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
                               /*err_on_failure=*/true, args...);
   }
@@ -114,14 +114,14 @@ class ROCMBlas : public blas::BlasSupport {
   // strided flavor
   template <typename T>
   port::Status AllocateStridedBuffer(
-      const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
-          &raw_ptrs,
+      const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type*>&
+          raw_ptrs,
       int batch_count, uint64_t batch_stride,
-      ScratchAllocator *scratch_allocator, Stream *stream,
+      ScratchAllocator* scratch_allocator, Stream* stream,
       std::unique_ptr<TemporaryDeviceMemory<
-          typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
-      DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
-          *device_memory);
+          typename RocBlasTypeConversionHelper<T>::mapped_type>>* temp_memory,
+      DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>*
+          device_memory);
 
   // A helper function to implement DoBlasGemmBatched interfaces for generic
   // types.
@@ -185,7 +185,7 @@ class ROCMBlas : public blas::BlasSupport {
   // GpuExecutor which instantiated this ROCMBlas.
   // Immutable post-initialization.
-  GpuExecutor *parent_;
+  GpuExecutor* parent_;
 
   // rocBLAS library handle on the device.
   rocblas_handle blas_ GUARDED_BY(mu_);
diff --git a/tensorflow/stream_executor/rocm/rocm_driver_wrapper.h b/tensorflow/stream_executor/rocm/rocm_driver_wrapper.h
index ba803edaafb..bc5b6a87888 100644
--- a/tensorflow/stream_executor/rocm/rocm_driver_wrapper.h
+++ b/tensorflow/stream_executor/rocm/rocm_driver_wrapper.h
@@ -27,10 +27,6 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/dso_loader.h"
 #include "tensorflow/stream_executor/platform/port.h"
 
-#if defined(TENSORFLOW_USE_ROCM)
-
-#endif
-
 namespace tensorflow {
 namespace wrap {
 #ifdef PLATFORM_GOOGLE
@@ -83,8 +79,8 @@ namespace wrap {
   __macro(hipDeviceTotalMem) \
   __macro(hipDriverGetVersion) \
   __macro(hipEventCreateWithFlags) \
-  __macro(hipEventElapsedTime) \
   __macro(hipEventDestroy) \
+  __macro(hipEventElapsedTime) \
   __macro(hipEventQuery) \
   __macro(hipEventRecord) \
   __macro(hipEventSynchronize) \
diff --git a/tensorflow/stream_executor/rocm/rocm_fft.cc b/tensorflow/stream_executor/rocm/rocm_fft.cc
index 2af973309c0..8ad5cc19ded 100644
--- a/tensorflow/stream_executor/rocm/rocm_fft.cc
+++ b/tensorflow/stream_executor/rocm/rocm_fft.cc
@@ -48,7 +48,7 @@ namespace wrap {
 #define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \
   struct WrapperShim__##__name { \
     template <typename... Args> \
-    hipfftResult operator()(GpuExecutor *parent, Args... args) { \
+    hipfftResult operator()(GpuExecutor* parent, Args... args) { \
      gpu::ScopedActivateExecutorContext sac{parent}; \
      return ::__name(args...); \
    } \
@@ -86,21 +86,33 @@ namespace wrap {
 
 #endif
 
-#define ROCFFT_ROUTINE_EACH(__macro) \
-  __macro(hipfftDestroy) __macro(hipfftSetStream) __macro(hipfftPlan1d) \
-  __macro(hipfftPlan2d) __macro(hipfftPlan3d) __macro(hipfftPlanMany) \
-  __macro(hipfftCreate) __macro(hipfftSetAutoAllocation) \
-  __macro(hipfftSetWorkArea) __macro(hipfftGetSize1d) \
-  __macro(hipfftMakePlan1d) __macro(hipfftGetSize2d) \
-  __macro(hipfftMakePlan2d) __macro(hipfftGetSize3d) \
-  __macro(hipfftMakePlan3d) __macro(hipfftGetSizeMany) \
-  __macro(hipfftMakePlanMany) \
-  __macro(hipfftExecD2Z) \
-  __macro(hipfftExecZ2D) \
-  __macro(hipfftExecC2C) \
-  __macro(hipfftExecC2R) \
-  __macro(hipfftExecZ2Z) \
-  __macro(hipfftExecR2C)
+// clang-format off
+#define ROCFFT_ROUTINE_EACH(__macro) \
+  __macro(hipfftDestroy) \
+  __macro(hipfftSetStream) \
+  __macro(hipfftPlan1d) \
+  __macro(hipfftPlan2d) \
+  __macro(hipfftPlan3d) \
+  __macro(hipfftPlanMany) \
+  __macro(hipfftCreate) \
+  __macro(hipfftSetAutoAllocation) \
+  __macro(hipfftSetWorkArea) \
+  __macro(hipfftGetSize1d) \
+  __macro(hipfftMakePlan1d) \
+  __macro(hipfftGetSize2d) \
+  __macro(hipfftMakePlan2d) \
+  __macro(hipfftGetSize3d) \
+  __macro(hipfftMakePlan3d) \
+  __macro(hipfftGetSizeMany) \
+  __macro(hipfftMakePlanMany) \
+  __macro(hipfftExecD2Z) \
+  __macro(hipfftExecZ2D) \
+  __macro(hipfftExecC2C) \
+  __macro(hipfftExecC2R) \
+  __macro(hipfftExecZ2Z) \
+  __macro(hipfftExecR2C) \
+
+// clang-format on
 
 ROCFFT_ROUTINE_EACH(STREAM_EXECUTOR_ROCFFT_WRAP)
 
@@ -515,7 +527,7 @@ bool ROCMFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT hipfftExec,
   }
 
   auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
-                        GpuComplex(const_cast<InputT *>(GpuMemory(input))),
+                        GpuComplex(const_cast<InputT*>(GpuMemory(input))),
                         GpuComplex(GpuMemoryMutable(output)));
 
   if (ret != HIPFFT_SUCCESS) {
@@ -542,7 +554,7 @@ bool ROCMFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
   }
 
   auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
-                        GpuComplex(const_cast<InputT *>(GpuMemory(input))),
+                        GpuComplex(const_cast<InputT*>(GpuMemory(input))),
                         GpuComplex(GpuMemoryMutable(output)),
                         rocm_fft_plan->GetFftDirection());
 
@@ -556,21 +568,21 @@ bool ROCMFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
 
 #define STREAM_EXECUTOR_ROCM_DEFINE_FFT(__type, __fft_type1, __fft_type2, \
                                         __fft_type3) \
-  bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
-                      const DeviceMemory<std::complex<__type>> &input, \
-                      DeviceMemory<std::complex<__type>> *output) { \
+  bool ROCMFft::DoFft(Stream* stream, fft::Plan* plan, \
+                      const DeviceMemory<std::complex<__type>>& input, \
+                      DeviceMemory<std::complex<__type>>* output) { \
     return DoFftWithDirectionInternal( \
         stream, plan, wrap::hipfftExec##__fft_type1, input, output); \
   } \
-  bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
-                      const DeviceMemory<__type> &input, \
-                      DeviceMemory<std::complex<__type>> *output) { \
+  bool ROCMFft::DoFft(Stream* stream, fft::Plan* plan, \
+                      const DeviceMemory<__type>& input, \
+                      DeviceMemory<std::complex<__type>>* output) { \
     return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type2, input, \
                          output); \
   } \
-  bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
-                      const DeviceMemory<std::complex<__type>> &input, \
-                      DeviceMemory<__type> *output) { \
+  bool ROCMFft::DoFft(Stream* stream, fft::Plan* plan, \
+                      const DeviceMemory<std::complex<__type>>& input, \
+                      DeviceMemory<__type>* output) { \
     return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type3, input, \
                          output); \
   }
@@ -590,9 +602,9 @@ void initialize_rocfft() {
   port::Status status =
       PluginRegistry::Instance()->RegisterFactory(
           rocm::kROCmPlatformId, gpu::kRocFftPlugin, "rocFFT",
-          [](internal::StreamExecutorInterface *parent) -> fft::FftSupport * {
-            gpu::GpuExecutor *rocm_executor =
-                dynamic_cast<gpu::GpuExecutor *>(parent);
+          [](internal::StreamExecutorInterface* parent) -> fft::FftSupport* {
+            gpu::GpuExecutor* rocm_executor =
+                dynamic_cast<gpu::GpuExecutor*>(parent);
             if (rocm_executor == nullptr) {
               LOG(ERROR) << "Attempting to initialize an instance of the rocFFT "
diff --git a/tensorflow/stream_executor/rocm/rocm_rng.cc b/tensorflow/stream_executor/rocm/rocm_rng.cc
index 38f4f8bb0c6..2492cc0e5d9 100644
--- a/tensorflow/stream_executor/rocm/rocm_rng.cc
+++ b/tensorflow/stream_executor/rocm/rocm_rng.cc
@@ -14,12 +14,11 @@ limitations under the License.
 ==============================================================================*/
 
 #include "rocm/include/hiprand/hiprand.h"
-#include "tensorflow/stream_executor/gpu/gpu_rng.h"
-
 #include "tensorflow/stream_executor/device_memory.h"
 #include "tensorflow/stream_executor/gpu/gpu_activation.h"
 #include "tensorflow/stream_executor/gpu/gpu_executor.h"
 #include "tensorflow/stream_executor/gpu/gpu_helpers.h"
+#include "tensorflow/stream_executor/gpu/gpu_rng.h"
 #include "tensorflow/stream_executor/gpu/gpu_stream.h"
 #include "tensorflow/stream_executor/lib/env.h"
 #include "tensorflow/stream_executor/lib/initialize.h"
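
For readers unfamiliar with the DynLoadShim wrapper macros touched above, the following standalone sketch illustrates the lazy dlopen/dlsym binding pattern they expand to. It is illustrative only and not part of the patch: the library ("libm.so.6"), the symbol ("cos"), and the shim name are hypothetical stand-ins for the rocBLAS DSO and its entry points, and plain fprintf/abort stands in for CHECK and port::Status error reporting.

// Minimal sketch of the lazy DSO-binding pattern (assumptions noted above).
#include <dlfcn.h>

#include <cstdio>
#include <cstdlib>

using cos_fn = double (*)(double);  // signature of the symbol we bind lazily

struct DynLoadShim_cos {
  static const char* kName;

  // Open the shared library once and cache the handle.
  static void* GetDsoHandle() {
    static void* handle = dlopen("libm.so.6", RTLD_NOW | RTLD_GLOBAL);
    if (handle == nullptr) {
      std::fprintf(stderr, "could not open DSO: %s\n", dlerror());
      std::abort();
    }
    return handle;
  }

  // Resolve the symbol once; fail loudly if it is missing, mirroring the
  // CHECK in the macro.
  static cos_fn LoadOrDie() {
    void* f = dlsym(GetDsoHandle(), kName);
    if (f == nullptr) {
      std::fprintf(stderr, "could not find %s: %s\n", kName, dlerror());
      std::abort();
    }
    return reinterpret_cast<cos_fn>(f);
  }

  // Forward calls through the lazily resolved, cached function pointer.
  double operator()(double x) {
    static cos_fn fn = LoadOrDie();
    return fn(x);
  }
} dyn_cos;

const char* DynLoadShim_cos::kName = "cos";

int main() {
  // The first call triggers dlopen/dlsym; later calls reuse the cached pointer.
  std::printf("cos(0) = %f\n", dyn_cos(0.0));
  return 0;
}

Built on Linux with something like "g++ example.cc -ldl", this prints cos(0) = 1.000000. The production macro differs mainly in that it derives the function-pointer type from the wrapped rocBLAS symbol and activates the executor context around each call.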