cosmetic / formatting changes

Deven Desai 2019-07-02 03:49:46 +00:00
parent e1f2191a31
commit d75510457b
5 changed files with 132 additions and 125 deletions
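Apart from a macro-list cleanup in the rocFFT wrapper and a small include reshuffle, the hunks below are a pointer/reference re-alignment: `Type *name` becomes `Type* name` (and `Type &name` becomes `Type& name`). A minimal, hypothetical illustration of the two spellings, which are identical to the compiler:

// Illustration only (not from the commit): pointer alignment is cosmetic.
const char *name_right_aligned;  // style before this commit
const char* name_left_aligned;   // style after this commit

void takes_ref_before(const int &x);  // reference, before
void takes_ref_after(const int& x);   // reference, after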

tensorflow/stream_executor/rocm/rocm_blas.cc

@@ -53,14 +53,14 @@ namespace wrap {
#ifdef PLATFORM_GOOGLE
#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \
struct WrapperShim__##__name { \
static const char *kName; \
static const char* kName; \
template <typename... Args> \
rocblas_status operator()(GpuExecutor *parent, Args... args) { \
rocblas_status operator()(GpuExecutor* parent, Args... args) { \
gpu::ScopedActivateExecutorContext sac{parent}; \
return ::__name(args...); \
} \
} __name; \
const char *WrapperShim__##__name::kName = #__name;
const char* WrapperShim__##__name::kName = #__name;
#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
@@ -69,14 +69,14 @@ namespace wrap {
#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \
struct DynLoadShim__##__name { \
static const char *kName; \
static const char* kName; \
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
static void *GetDsoHandle() { \
static void* GetDsoHandle() { \
auto s = internal::CachedDsoLoader::GetRocblasDsoHandle(); \
return s.ValueOrDie(); \
} \
static FuncPtrT LoadOrDie() { \
void *f; \
void* f; \
auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
@@ -88,12 +88,12 @@ namespace wrap {
return f; \
} \
template <typename... Args> \
rocblas_status operator()(GpuExecutor *parent, Args... args) { \
rocblas_status operator()(GpuExecutor* parent, Args... args) { \
gpu::ScopedActivateExecutorContext sac{parent}; \
return DynLoad()(args...); \
} \
} __name; \
const char *DynLoadShim__##__name::kName = #__name;
const char* DynLoadShim__##__name::kName = #__name;
#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
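For readers unfamiliar with the pattern being reformatted here: the wrapper macro generates one shim per rocBLAS entry point that resolves the symbol from the DSO on first use and caches it. A self-contained sketch of that lazy-resolution idea (names and signature simplified; the real shim also threads a GpuExecutor through operator() and obtains the handle via CachedDsoLoader):

#include <dlfcn.h>
#include <cstdio>
#include <cstdlib>

// Sketch of one STREAM_EXECUTOR_ROCBLAS_WRAP expansion (hypothetical name).
struct DynLoadShim__rocblas_example {
  static const char* kName;
  using FuncPtrT = int (*)();  // placeholder signature

  static void* GetDsoHandle() {
    static void* handle = dlopen("librocblas.so", RTLD_LAZY);
    if (handle == nullptr) {
      std::fprintf(stderr, "could not load librocblas.so\n");
      std::abort();
    }
    return handle;
  }

  static FuncPtrT LoadOrDie() {
    void* f = dlsym(GetDsoHandle(), kName);
    if (f == nullptr) {
      std::fprintf(stderr, "could not find %s\n", kName);
      std::abort();
    }
    return reinterpret_cast<FuncPtrT>(f);
  }

  static FuncPtrT DynLoad() {
    static FuncPtrT f = LoadOrDie();  // resolved once, then cached
    return f;
  }
};
const char* DynLoadShim__rocblas_example::kName = "rocblas_example";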
@@ -322,7 +322,7 @@ bool ROCMBlas::Init() {
return true;
}
ROCMBlas::ROCMBlas(gpu::GpuExecutor *parent)
ROCMBlas::ROCMBlas(gpu::GpuExecutor* parent)
: parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {}
ROCMBlas::~ROCMBlas() {
@@ -1476,11 +1476,11 @@ bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
return false;
}
bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
bool ROCMBlas::DoBlasGemm(Stream* stream, blas::Transpose transa,
blas::Transpose transb, uint64 m, uint64 n, uint64 k,
float alpha, const DeviceMemory<Eigen::half> &a,
int lda, const DeviceMemory<Eigen::half> &b, int ldb,
float beta, DeviceMemory<Eigen::half> *c, int ldc) {
float alpha, const DeviceMemory<Eigen::half>& a,
int lda, const DeviceMemory<Eigen::half>& b, int ldb,
float beta, DeviceMemory<Eigen::half>* c, int ldc) {
VLOG(1) << absl::StreamFormat(
"doing rocBLAS SGEMM: at=%d bt=%d m=%u n=%u "
"k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
@@ -1514,11 +1514,11 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
return DoBlasInternal(
wrap::rocblas_hgemm, stream, true /* = pointer_mode_host */,
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
reinterpret_cast<const rocblas_half *>(&alpha_half),
reinterpret_cast<const rocblas_half *>(GpuMemory(a)), lda,
reinterpret_cast<const rocblas_half *>(GpuMemory(b)), ldb,
reinterpret_cast<const rocblas_half *>(&beta_half),
reinterpret_cast<rocblas_half *>(GpuMemoryMutable(c)), ldc);
reinterpret_cast<const rocblas_half*>(&alpha_half),
reinterpret_cast<const rocblas_half*>(GpuMemory(a)), lda,
reinterpret_cast<const rocblas_half*>(GpuMemory(b)), ldb,
reinterpret_cast<const rocblas_half*>(&beta_half),
reinterpret_cast<rocblas_half*>(GpuMemoryMutable(c)), ldc);
}
bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
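The reinterpret_casts in the half-precision path above depend on Eigen::half and rocblas_half both being plain 16-bit storage types, so device buffers can be handed to rocBLAS without converting elements. A sketch of that assumption with stand-in types (the real types live in Eigen and rocBLAS):

#include <cstdint>

struct EigenHalfLike   { std::uint16_t raw; };  // stand-in for Eigen::half
struct RocblasHalfLike { std::uint16_t raw; };  // stand-in for rocblas_half

static_assert(sizeof(EigenHalfLike) == sizeof(RocblasHalfLike),
              "pointer reinterpretation is only sound if the "
              "representations match");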
@@ -1731,12 +1731,12 @@ bool ROCMBlas::GetBlasGemmAlgorithms(
}
bool ROCMBlas::DoBlasGemmWithAlgorithm(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,
const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b, int ldb,
const HostOrDeviceScalar<int> &beta, DeviceMemory<int32> *c, int ldc,
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, const HostOrDeviceScalar<int>& alpha,
const DeviceMemory<int8>& a, int lda, const DeviceMemory<int8>& b, int ldb,
const HostOrDeviceScalar<int>& beta, DeviceMemory<int32>* c, int ldc,
blas::ComputationType computation_type, blas::AlgorithmType algorithm,
blas::ProfileResult *output_profile_result) {
blas::ProfileResult* output_profile_result) {
LOG(ERROR)
<< "rocBLAS does not currently support the GEMMwithAlgorithm operation "
<< "for the \"int8\" dataype";
@@ -1744,13 +1744,13 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
}
bool ROCMBlas::DoBlasGemmWithAlgorithm(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
const DeviceMemory<Eigen::half> &a, int lda,
const DeviceMemory<Eigen::half> &b, int ldb,
const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half>& alpha,
const DeviceMemory<Eigen::half>& a, int lda,
const DeviceMemory<Eigen::half>& b, int ldb,
const HostOrDeviceScalar<Eigen::half>& beta, DeviceMemory<Eigen::half>* c,
int ldc, blas::ComputationType computation_type,
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
blas::AlgorithmType algorithm, blas::ProfileResult* output_profile_result) {
LOG(ERROR)
<< "rocBLAS does not currently support the GEMMwithAlgorithm operation "
<< "for the \"half\" dataype";
@@ -1758,12 +1758,12 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
}
bool ROCMBlas::DoBlasGemmWithAlgorithm(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,
const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, const HostOrDeviceScalar<float>& alpha,
const DeviceMemory<float>& a, int lda, const DeviceMemory<float>& b,
int ldb, const HostOrDeviceScalar<float>& beta, DeviceMemory<float>* c,
int ldc, blas::ComputationType computation_type,
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
blas::AlgorithmType algorithm, blas::ProfileResult* output_profile_result) {
LOG(ERROR)
<< "rocBLAS does not currently support the GEMMwithAlgorithm operation "
<< "for the \"float\" dataype";
@@ -1771,12 +1771,12 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
}
bool ROCMBlas::DoBlasGemmWithAlgorithm(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,
const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, const HostOrDeviceScalar<double>& alpha,
const DeviceMemory<double>& a, int lda, const DeviceMemory<double>& b,
int ldb, const HostOrDeviceScalar<double>& beta, DeviceMemory<double>* c,
int ldc, blas::ComputationType computation_type,
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
blas::AlgorithmType algorithm, blas::ProfileResult* output_profile_result) {
LOG(ERROR)
<< "rocBLAS does not currently support the GEMMwithAlgorithm operation "
<< "for the \"double\" dataype";
@@ -1815,14 +1815,14 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
template <typename T>
port::Status ROCMBlas::AllocateStridedBuffer(
const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
&raw_ptrs,
int batch_count, uint64_t batch_stride, ScratchAllocator *scratch_allocator,
Stream *stream,
const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type*>&
raw_ptrs,
int batch_count, uint64_t batch_stride, ScratchAllocator* scratch_allocator,
Stream* stream,
std::unique_ptr<TemporaryDeviceMemory<
typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
*device_memory) {
typename RocBlasTypeConversionHelper<T>::mapped_type>>* temp_memory,
DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>*
device_memory) {
assert(device_memory != nullptr);
using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type;
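Context for this helper, inferred from its signature (hedged, not stated in the commit): rocblas_*_strided_batched kernels expect one allocation in which matrix i starts at base + i * batch_stride, while the batched StreamExecutor API supplies a vector of arbitrary per-matrix pointers, so the helper must either prove the pointers already form such a layout or repack them into scratch memory. A minimal version of the contiguity check:

#include <cstddef>
#include <cstdint>
#include <vector>

// Returns true when the batch already forms one strided allocation, i.e.
// raw_ptrs[i] == raw_ptrs[0] + i * batch_stride_elems for every i.
template <typename T>
bool IsContiguousStridedBatch(const std::vector<T*>& raw_ptrs,
                              std::uint64_t batch_stride_elems) {
  for (std::size_t i = 1; i < raw_ptrs.size(); ++i) {
    if (raw_ptrs[i] != raw_ptrs[i - 1] + batch_stride_elems) return false;
  }
  return true;
}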
@@ -1969,12 +1969,12 @@ port::Status ROCMBlas::DoBlasGemmBatchedInternal(
}
bool ROCMBlas::DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, float alpha,
const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
int batch_count, ScratchAllocator *scratch_allocator) {
const port::ArraySlice<DeviceMemory<Eigen::half>*>& a, int lda,
const port::ArraySlice<DeviceMemory<Eigen::half>*>& b, int ldb, float beta,
const port::ArraySlice<DeviceMemory<Eigen::half>*>& c, int ldc,
int batch_count, ScratchAllocator* scratch_allocator) {
const Eigen::half alpha_half(alpha);
const Eigen::half beta_half(beta);
@@ -2299,7 +2299,7 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
return DoBlasInternal(
wrap::rocblas_strsm, stream, true /* = pointer_mode_host */,
ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<float *>(GpuMemory(a)),
ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<float*>(GpuMemory(a)),
lda, GpuMemoryMutable(b), ldb);
}
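The const_cast above (and in the dtrsm overload below) is the usual workaround for a C API parameter that is declared non-const but only read. A minimal illustration under that assumption (legacy_trsm is hypothetical, not the real rocBLAS signature):

extern "C" int legacy_trsm(float* a);  // hypothetical read-only, non-const API

int CallWithConstInput(const float* a) {
  // Sound only because legacy_trsm never writes through 'a'.
  return legacy_trsm(const_cast<float*>(a));
}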
@@ -2311,7 +2311,7 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
return DoBlasInternal(
wrap::rocblas_dtrsm, stream, true /* = pointer_mode_host */,
ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<double *>(GpuMemory(a)),
ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<double*>(GpuMemory(a)),
lda, GpuMemoryMutable(b), ldb);
}
@@ -2349,19 +2349,19 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
wrap::rocblas_hgemm_strided_batched, stream,
false, /* pointer_mode_host */
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
reinterpret_cast<const rocblas_half *>(&alpha_half),
reinterpret_cast<const rocblas_half *>(GpuMemory(a)), lda, stride_a,
reinterpret_cast<const rocblas_half *>(GpuMemory(b)), ldb, stride_b,
reinterpret_cast<const rocblas_half *>(&beta_half),
reinterpret_cast<rocblas_half *>(GpuMemoryMutable(c)), ldc, stride_c,
reinterpret_cast<const rocblas_half*>(&alpha_half),
reinterpret_cast<const rocblas_half*>(GpuMemory(a)), lda, stride_a,
reinterpret_cast<const rocblas_half*>(GpuMemory(b)), ldb, stride_b,
reinterpret_cast<const rocblas_half*>(&beta_half),
reinterpret_cast<rocblas_half*>(GpuMemoryMutable(c)), ldc, stride_c,
batch_count);
}
bool ROCMBlas::DoBlasGemmStridedBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, float alpha, const DeviceMemory<float>& a, int lda,
int64 stride_a, const DeviceMemory<float>& b, int ldb, int64 stride_b,
float beta, DeviceMemory<float>* c, int ldc, int64 stride_c,
int batch_count) {
return DoBlasInternal(wrap::rocblas_sgemm_strided_batched, stream,
false, /* pointer_mode_host */
@@ -2371,10 +2371,10 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
stride_c, batch_count);
}
bool ROCMBlas::DoBlasGemmStridedBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, double alpha, const DeviceMemory<double>& a, int lda,
int64 stride_a, const DeviceMemory<double>& b, int ldb, int64 stride_b,
double beta, DeviceMemory<double>* c, int ldc, int64 stride_c,
int batch_count) {
return DoBlasInternal(wrap::rocblas_dgemm_strided_batched, stream,
false, /* pointer_mode_host */
@@ -2384,11 +2384,11 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
stride_c, batch_count);
}
bool ROCMBlas::DoBlasGemmStridedBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, std::complex<float> alpha,
const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
const DeviceMemory<std::complex<float>>& a, int lda, int64 stride_a,
const DeviceMemory<std::complex<float>>& b, int ldb, int64 stride_b,
std::complex<float> beta, DeviceMemory<std::complex<float>>* c, int ldc,
int64 stride_c, int batch_count) {
LOG(ERROR) << "rocBLAS does not currently support the "
"DoBlasGemmStridedBatched operation "
@@ -2396,11 +2396,11 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
return false;
}
bool ROCMBlas::DoBlasGemmStridedBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, std::complex<double> alpha,
const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
const DeviceMemory<std::complex<double>>& a, int lda, int64 stride_a,
const DeviceMemory<std::complex<double>>& b, int ldb, int64 stride_b,
std::complex<double> beta, DeviceMemory<std::complex<double>>* c, int ldc,
int64 stride_c, int batch_count) {
LOG(ERROR) << "rocBLAS does not currently support the "
"DoBlasGemmStridedBatched operation "
@@ -2418,10 +2418,10 @@ void initialize_rocblas() {
PluginRegistry::Instance()
->RegisterFactory<PluginRegistry::BlasFactory>(
rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS",
[](internal::StreamExecutorInterface *parent)
-> blas::BlasSupport * {
gpu::GpuExecutor *rocm_executor =
dynamic_cast<gpu::GpuExecutor *>(parent);
[](internal::StreamExecutorInterface* parent)
-> blas::BlasSupport* {
gpu::GpuExecutor* rocm_executor =
dynamic_cast<gpu::GpuExecutor*>(parent);
if (rocm_executor == nullptr) {
LOG(ERROR)
<< "Attempting to initialize an instance of the "
@@ -2430,7 +2430,7 @@ void initialize_rocblas() {
return nullptr;
}
gpu::ROCMBlas *blas = new gpu::ROCMBlas(rocm_executor);
gpu::ROCMBlas* blas = new gpu::ROCMBlas(rocm_executor);
if (!blas->Init()) {
// Note: Init() will log a more specific error.
delete blas;
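The registration above follows a factory-with-downcast pattern: the lambda receives the generic executor interface, dynamic_casts it to the ROCm executor, and returns nullptr on mismatch so registration fails softly. A simplified sketch with hypothetical stand-in types:

#include <iostream>

struct ExecutorInterface { virtual ~ExecutorInterface() = default; };
struct RocmExecutor : ExecutorInterface {};
struct BlasSupport { bool Init() { return true; } };

BlasSupport* MakeRocBlas(ExecutorInterface* parent) {
  auto* rocm_executor = dynamic_cast<RocmExecutor*>(parent);
  if (rocm_executor == nullptr) {
    std::cerr << "not a ROCm executor; skipping rocBLAS registration\n";
    return nullptr;  // the registry treats nullptr as "plugin unavailable"
  }
  auto* blas = new BlasSupport();
  if (!blas->Init()) {  // mirrors the Init()/delete sequence above
    delete blas;
    return nullptr;
  }
  return blas;
}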

tensorflow/stream_executor/rocm/rocm_blas.h

@@ -62,7 +62,7 @@ class GpuExecutor;
// Thread-safe post-initialization.
class ROCMBlas : public blas::BlasSupport {
public:
explicit ROCMBlas(GpuExecutor *parent);
explicit ROCMBlas(GpuExecutor* parent);
// Allocates a rocBLAS handle.
bool Init();
@@ -98,7 +98,7 @@ class ROCMBlas : public blas::BlasSupport {
// Convenience functions that call DoBlasInternalImpl with different values
// for err_on_failure.
template <typename FuncT, typename... Args>
bool DoBlasInternal(FuncT rocblas_func, Stream *stream,
bool DoBlasInternal(FuncT rocblas_func, Stream* stream,
bool pointer_mode_host, Args... args) {
return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
/*err_on_failure=*/true, args...);
@@ -114,14 +114,14 @@ class ROCMBlas : public blas::BlasSupport {
// strided flavor
template <typename T>
port::Status AllocateStridedBuffer(
const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
&raw_ptrs,
const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type*>&
raw_ptrs,
int batch_count, uint64_t batch_stride,
ScratchAllocator *scratch_allocator, Stream *stream,
ScratchAllocator* scratch_allocator, Stream* stream,
std::unique_ptr<TemporaryDeviceMemory<
typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
*device_memory);
typename RocBlasTypeConversionHelper<T>::mapped_type>>* temp_memory,
DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>*
device_memory);
// A helper function to implement DoBlasGemmBatched interfaces for generic
// types.
@@ -185,7 +185,7 @@ class ROCMBlas : public blas::BlasSupport {
// GpuExecutor which instantiated this ROCMBlas.
// Immutable post-initialization.
GpuExecutor *parent_;
GpuExecutor* parent_;
// rocBLAS library handle on the device.
rocblas_handle blas_ GUARDED_BY(mu_);

tensorflow/stream_executor/rocm/rocm_driver_wrapper.h

@@ -27,10 +27,6 @@ limitations under the License.
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "tensorflow/stream_executor/platform/port.h"
#if defined(TENSORFLOW_USE_ROCM)
#endif
namespace tensorflow {
namespace wrap {
#ifdef PLATFORM_GOOGLE
@@ -83,8 +79,8 @@ namespace wrap {
__macro(hipDeviceTotalMem) \
__macro(hipDriverGetVersion) \
__macro(hipEventCreateWithFlags) \
__macro(hipEventElapsedTime) \
__macro(hipEventDestroy) \
__macro(hipEventElapsedTime) \
__macro(hipEventQuery) \
__macro(hipEventRecord) \
__macro(hipEventSynchronize) \

tensorflow/stream_executor/rocm/rocm_fft.cc

@@ -48,7 +48,7 @@ namespace wrap {
#define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \
struct WrapperShim__##__name { \
template <typename... Args> \
hipfftResult operator()(GpuExecutor *parent, Args... args) { \
hipfftResult operator()(GpuExecutor* parent, Args... args) { \
gpu::ScopedActivateExecutorContext sac{parent}; \
return ::__name(args...); \
} \
@@ -86,21 +86,33 @@ namespace wrap {
#endif
#define ROCFFT_ROUTINE_EACH(__macro) \
__macro(hipfftDestroy) __macro(hipfftSetStream) __macro(hipfftPlan1d) \
__macro(hipfftPlan2d) __macro(hipfftPlan3d) __macro(hipfftPlanMany) \
__macro(hipfftCreate) __macro(hipfftSetAutoAllocation) \
__macro(hipfftSetWorkArea) __macro(hipfftGetSize1d) \
__macro(hipfftMakePlan1d) __macro(hipfftGetSize2d) \
__macro(hipfftMakePlan2d) __macro(hipfftGetSize3d) \
__macro(hipfftMakePlan3d) __macro(hipfftGetSizeMany) \
__macro(hipfftMakePlanMany) \
__macro(hipfftExecD2Z) \
__macro(hipfftExecZ2D) \
__macro(hipfftExecC2C) \
__macro(hipfftExecC2R) \
__macro(hipfftExecZ2Z) \
__macro(hipfftExecR2C)
// clang-format off
#define ROCFFT_ROUTINE_EACH(__macro) \
__macro(hipfftDestroy) \
__macro(hipfftSetStream) \
__macro(hipfftPlan1d) \
__macro(hipfftPlan2d) \
__macro(hipfftPlan3d) \
__macro(hipfftPlanMany) \
__macro(hipfftCreate) \
__macro(hipfftSetAutoAllocation) \
__macro(hipfftSetWorkArea) \
__macro(hipfftGetSize1d) \
__macro(hipfftMakePlan1d) \
__macro(hipfftGetSize2d) \
__macro(hipfftMakePlan2d) \
__macro(hipfftGetSize3d) \
__macro(hipfftMakePlan3d) \
__macro(hipfftGetSizeMany) \
__macro(hipfftMakePlanMany) \
__macro(hipfftExecD2Z) \
__macro(hipfftExecZ2D) \
__macro(hipfftExecC2C) \
__macro(hipfftExecC2R) \
__macro(hipfftExecZ2Z) \
__macro(hipfftExecR2C) \
// clang-format on
ROCFFT_ROUTINE_EACH(STREAM_EXECUTOR_ROCFFT_WRAP)
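The list/expander pair above is the classic X-macro technique: one macro enumerates the hipFFT entry points, and a per-item macro is applied to each name to stamp out a wrapper. A tiny self-contained sketch (hypothetical names):

#include <cstdio>

#define MY_ROUTINE_EACH(__macro) \
  __macro(open)                  \
  __macro(close)

// Applied to every name in the list; generates wrap_open() and wrap_close().
#define DEFINE_WRAPPER(__name) \
  void wrap_##__name() { std::printf("calling %s\n", #__name); }

MY_ROUTINE_EACH(DEFINE_WRAPPER)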
@@ -515,7 +527,7 @@ bool ROCMFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT hipfftExec,
}
auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
GpuComplex(const_cast<InputT *>(GpuMemory(input))),
GpuComplex(const_cast<InputT*>(GpuMemory(input))),
GpuComplex(GpuMemoryMutable(output)));
if (ret != HIPFFT_SUCCESS) {
@@ -542,7 +554,7 @@ bool ROCMFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
}
auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
GpuComplex(const_cast<InputT *>(GpuMemory(input))),
GpuComplex(const_cast<InputT*>(GpuMemory(input))),
GpuComplex(GpuMemoryMutable(output)),
rocm_fft_plan->GetFftDirection());
@@ -556,21 +568,21 @@ bool ROCMFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
#define STREAM_EXECUTOR_ROCM_DEFINE_FFT(__type, __fft_type1, __fft_type2, \
__fft_type3) \
bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
const DeviceMemory<std::complex<__type>> &input, \
DeviceMemory<std::complex<__type>> *output) { \
bool ROCMFft::DoFft(Stream* stream, fft::Plan* plan, \
const DeviceMemory<std::complex<__type>>& input, \
DeviceMemory<std::complex<__type>>* output) { \
return DoFftWithDirectionInternal( \
stream, plan, wrap::hipfftExec##__fft_type1, input, output); \
} \
bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
const DeviceMemory<__type> &input, \
DeviceMemory<std::complex<__type>> *output) { \
bool ROCMFft::DoFft(Stream* stream, fft::Plan* plan, \
const DeviceMemory<__type>& input, \
DeviceMemory<std::complex<__type>>* output) { \
return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type2, input, \
output); \
} \
bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
const DeviceMemory<std::complex<__type>> &input, \
DeviceMemory<__type> *output) { \
bool ROCMFft::DoFft(Stream* stream, fft::Plan* plan, \
const DeviceMemory<std::complex<__type>>& input, \
DeviceMemory<__type>* output) { \
return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type3, input, \
output); \
}
@@ -590,9 +602,9 @@ void initialize_rocfft() {
port::Status status =
PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>(
rocm::kROCmPlatformId, gpu::kRocFftPlugin, "rocFFT",
[](internal::StreamExecutorInterface *parent) -> fft::FftSupport * {
gpu::GpuExecutor *rocm_executor =
dynamic_cast<gpu::GpuExecutor *>(parent);
[](internal::StreamExecutorInterface* parent) -> fft::FftSupport* {
gpu::GpuExecutor* rocm_executor =
dynamic_cast<gpu::GpuExecutor*>(parent);
if (rocm_executor == nullptr) {
LOG(ERROR)
<< "Attempting to initialize an instance of the rocFFT "

tensorflow/stream_executor/rocm/rocm_rng.cc

@@ -14,12 +14,11 @@ limitations under the License.
==============================================================================*/
#include "rocm/include/hiprand/hiprand.h"
#include "tensorflow/stream_executor/gpu/gpu_rng.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
#include "tensorflow/stream_executor/gpu/gpu_rng.h"
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/initialize.h"