cosmetic / formating changes
This commit is contained in:
parent
e1f2191a31
commit
d75510457b
@ -53,14 +53,14 @@ namespace wrap {
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \
|
||||
struct WrapperShim__##__name { \
|
||||
static const char *kName; \
|
||||
static const char* kName; \
|
||||
template <typename... Args> \
|
||||
rocblas_status operator()(GpuExecutor *parent, Args... args) { \
|
||||
rocblas_status operator()(GpuExecutor* parent, Args... args) { \
|
||||
gpu::ScopedActivateExecutorContext sac{parent}; \
|
||||
return ::__name(args...); \
|
||||
} \
|
||||
} __name; \
|
||||
const char *WrapperShim__##__name::kName = #__name;
|
||||
const char* WrapperShim__##__name::kName = #__name;
|
||||
|
||||
#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
|
||||
STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
|
||||
@ -69,14 +69,14 @@ namespace wrap {
|
||||
|
||||
#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \
|
||||
struct DynLoadShim__##__name { \
|
||||
static const char *kName; \
|
||||
static const char* kName; \
|
||||
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
|
||||
static void *GetDsoHandle() { \
|
||||
static void* GetDsoHandle() { \
|
||||
auto s = internal::CachedDsoLoader::GetRocblasDsoHandle(); \
|
||||
return s.ValueOrDie(); \
|
||||
} \
|
||||
static FuncPtrT LoadOrDie() { \
|
||||
void *f; \
|
||||
void* f; \
|
||||
auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
|
||||
kName, &f); \
|
||||
CHECK(s.ok()) << "could not find " << kName \
|
||||
@ -88,12 +88,12 @@ namespace wrap {
|
||||
return f; \
|
||||
} \
|
||||
template <typename... Args> \
|
||||
rocblas_status operator()(GpuExecutor *parent, Args... args) { \
|
||||
rocblas_status operator()(GpuExecutor* parent, Args... args) { \
|
||||
gpu::ScopedActivateExecutorContext sac{parent}; \
|
||||
return DynLoad()(args...); \
|
||||
} \
|
||||
} __name; \
|
||||
const char *DynLoadShim__##__name::kName = #__name;
|
||||
const char* DynLoadShim__##__name::kName = #__name;
|
||||
|
||||
#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
|
||||
STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
|
||||
@ -322,7 +322,7 @@ bool ROCMBlas::Init() {
|
||||
return true;
|
||||
}
|
||||
|
||||
ROCMBlas::ROCMBlas(gpu::GpuExecutor *parent)
|
||||
ROCMBlas::ROCMBlas(gpu::GpuExecutor* parent)
|
||||
: parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {}
|
||||
|
||||
ROCMBlas::~ROCMBlas() {
|
||||
@ -1476,11 +1476,11 @@ bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
||||
bool ROCMBlas::DoBlasGemm(Stream* stream, blas::Transpose transa,
|
||||
blas::Transpose transb, uint64 m, uint64 n, uint64 k,
|
||||
float alpha, const DeviceMemory<Eigen::half> &a,
|
||||
int lda, const DeviceMemory<Eigen::half> &b, int ldb,
|
||||
float beta, DeviceMemory<Eigen::half> *c, int ldc) {
|
||||
float alpha, const DeviceMemory<Eigen::half>& a,
|
||||
int lda, const DeviceMemory<Eigen::half>& b, int ldb,
|
||||
float beta, DeviceMemory<Eigen::half>* c, int ldc) {
|
||||
VLOG(1) << absl::StreamFormat(
|
||||
"doing rocBLAS SGEMM: at=%d bt=%d m=%u n=%u "
|
||||
"k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
|
||||
@ -1514,11 +1514,11 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
||||
return DoBlasInternal(
|
||||
wrap::rocblas_hgemm, stream, true /* = pointer_mode_host */,
|
||||
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
|
||||
reinterpret_cast<const rocblas_half *>(&alpha_half),
|
||||
reinterpret_cast<const rocblas_half *>(GpuMemory(a)), lda,
|
||||
reinterpret_cast<const rocblas_half *>(GpuMemory(b)), ldb,
|
||||
reinterpret_cast<const rocblas_half *>(&beta_half),
|
||||
reinterpret_cast<rocblas_half *>(GpuMemoryMutable(c)), ldc);
|
||||
reinterpret_cast<const rocblas_half*>(&alpha_half),
|
||||
reinterpret_cast<const rocblas_half*>(GpuMemory(a)), lda,
|
||||
reinterpret_cast<const rocblas_half*>(GpuMemory(b)), ldb,
|
||||
reinterpret_cast<const rocblas_half*>(&beta_half),
|
||||
reinterpret_cast<rocblas_half*>(GpuMemoryMutable(c)), ldc);
|
||||
}
|
||||
|
||||
bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
||||
@ -1731,12 +1731,12 @@ bool ROCMBlas::GetBlasGemmAlgorithms(
|
||||
}
|
||||
|
||||
bool ROCMBlas::DoBlasGemmWithAlgorithm(
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,
|
||||
const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b, int ldb,
|
||||
const HostOrDeviceScalar<int> &beta, DeviceMemory<int32> *c, int ldc,
|
||||
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, const HostOrDeviceScalar<int>& alpha,
|
||||
const DeviceMemory<int8>& a, int lda, const DeviceMemory<int8>& b, int ldb,
|
||||
const HostOrDeviceScalar<int>& beta, DeviceMemory<int32>* c, int ldc,
|
||||
blas::ComputationType computation_type, blas::AlgorithmType algorithm,
|
||||
blas::ProfileResult *output_profile_result) {
|
||||
blas::ProfileResult* output_profile_result) {
|
||||
LOG(ERROR)
|
||||
<< "rocBLAS does not currently support the GEMMwithAlgorithm operation "
|
||||
<< "for the \"int8\" dataype";
|
||||
@ -1744,13 +1744,13 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
|
||||
}
|
||||
|
||||
bool ROCMBlas::DoBlasGemmWithAlgorithm(
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
|
||||
const DeviceMemory<Eigen::half> &a, int lda,
|
||||
const DeviceMemory<Eigen::half> &b, int ldb,
|
||||
const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
|
||||
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half>& alpha,
|
||||
const DeviceMemory<Eigen::half>& a, int lda,
|
||||
const DeviceMemory<Eigen::half>& b, int ldb,
|
||||
const HostOrDeviceScalar<Eigen::half>& beta, DeviceMemory<Eigen::half>* c,
|
||||
int ldc, blas::ComputationType computation_type,
|
||||
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
|
||||
blas::AlgorithmType algorithm, blas::ProfileResult* output_profile_result) {
|
||||
LOG(ERROR)
|
||||
<< "rocBLAS does not currently support the GEMMwithAlgorithm operation "
|
||||
<< "for the \"half\" dataype";
|
||||
@ -1758,12 +1758,12 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
|
||||
}
|
||||
|
||||
bool ROCMBlas::DoBlasGemmWithAlgorithm(
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,
|
||||
const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
|
||||
int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
|
||||
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, const HostOrDeviceScalar<float>& alpha,
|
||||
const DeviceMemory<float>& a, int lda, const DeviceMemory<float>& b,
|
||||
int ldb, const HostOrDeviceScalar<float>& beta, DeviceMemory<float>* c,
|
||||
int ldc, blas::ComputationType computation_type,
|
||||
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
|
||||
blas::AlgorithmType algorithm, blas::ProfileResult* output_profile_result) {
|
||||
LOG(ERROR)
|
||||
<< "rocBLAS does not currently support the GEMMwithAlgorithm operation "
|
||||
<< "for the \"float\" dataype";
|
||||
@ -1771,12 +1771,12 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
|
||||
}
|
||||
|
||||
bool ROCMBlas::DoBlasGemmWithAlgorithm(
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,
|
||||
const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
|
||||
int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
|
||||
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, const HostOrDeviceScalar<double>& alpha,
|
||||
const DeviceMemory<double>& a, int lda, const DeviceMemory<double>& b,
|
||||
int ldb, const HostOrDeviceScalar<double>& beta, DeviceMemory<double>* c,
|
||||
int ldc, blas::ComputationType computation_type,
|
||||
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
|
||||
blas::AlgorithmType algorithm, blas::ProfileResult* output_profile_result) {
|
||||
LOG(ERROR)
|
||||
<< "rocBLAS does not currently support the GEMMwithAlgorithm operation "
|
||||
<< "for the \"double\" dataype";
|
||||
@ -1815,14 +1815,14 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
|
||||
|
||||
template <typename T>
|
||||
port::Status ROCMBlas::AllocateStridedBuffer(
|
||||
const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
|
||||
&raw_ptrs,
|
||||
int batch_count, uint64_t batch_stride, ScratchAllocator *scratch_allocator,
|
||||
Stream *stream,
|
||||
const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type*>&
|
||||
raw_ptrs,
|
||||
int batch_count, uint64_t batch_stride, ScratchAllocator* scratch_allocator,
|
||||
Stream* stream,
|
||||
std::unique_ptr<TemporaryDeviceMemory<
|
||||
typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
|
||||
DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
|
||||
*device_memory) {
|
||||
typename RocBlasTypeConversionHelper<T>::mapped_type>>* temp_memory,
|
||||
DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>*
|
||||
device_memory) {
|
||||
assert(device_memory != nullptr);
|
||||
|
||||
using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type;
|
||||
@ -1969,12 +1969,12 @@ port::Status ROCMBlas::DoBlasGemmBatchedInternal(
|
||||
}
|
||||
|
||||
bool ROCMBlas::DoBlasGemmBatched(
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, float alpha,
|
||||
const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
|
||||
const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
|
||||
const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
|
||||
int batch_count, ScratchAllocator *scratch_allocator) {
|
||||
const port::ArraySlice<DeviceMemory<Eigen::half>*>& a, int lda,
|
||||
const port::ArraySlice<DeviceMemory<Eigen::half>*>& b, int ldb, float beta,
|
||||
const port::ArraySlice<DeviceMemory<Eigen::half>*>& c, int ldc,
|
||||
int batch_count, ScratchAllocator* scratch_allocator) {
|
||||
const Eigen::half alpha_half(alpha);
|
||||
const Eigen::half beta_half(beta);
|
||||
|
||||
@ -2299,7 +2299,7 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
|
||||
return DoBlasInternal(
|
||||
wrap::rocblas_strsm, stream, true /* = pointer_mode_host */,
|
||||
ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
|
||||
ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<float *>(GpuMemory(a)),
|
||||
ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<float*>(GpuMemory(a)),
|
||||
lda, GpuMemoryMutable(b), ldb);
|
||||
}
|
||||
|
||||
@ -2311,7 +2311,7 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
|
||||
return DoBlasInternal(
|
||||
wrap::rocblas_dtrsm, stream, true /* = pointer_mode_host */,
|
||||
ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
|
||||
ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<double *>(GpuMemory(a)),
|
||||
ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<double*>(GpuMemory(a)),
|
||||
lda, GpuMemoryMutable(b), ldb);
|
||||
}
|
||||
|
||||
@ -2349,19 +2349,19 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
|
||||
wrap::rocblas_hgemm_strided_batched, stream,
|
||||
false, /* pointer_mode_host */
|
||||
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
|
||||
reinterpret_cast<const rocblas_half *>(&alpha_half),
|
||||
reinterpret_cast<const rocblas_half *>(GpuMemory(a)), lda, stride_a,
|
||||
reinterpret_cast<const rocblas_half *>(GpuMemory(b)), ldb, stride_b,
|
||||
reinterpret_cast<const rocblas_half *>(&beta_half),
|
||||
reinterpret_cast<rocblas_half *>(GpuMemoryMutable(c)), ldc, stride_c,
|
||||
reinterpret_cast<const rocblas_half*>(&alpha_half),
|
||||
reinterpret_cast<const rocblas_half*>(GpuMemory(a)), lda, stride_a,
|
||||
reinterpret_cast<const rocblas_half*>(GpuMemory(b)), ldb, stride_b,
|
||||
reinterpret_cast<const rocblas_half*>(&beta_half),
|
||||
reinterpret_cast<rocblas_half*>(GpuMemoryMutable(c)), ldc, stride_c,
|
||||
batch_count);
|
||||
}
|
||||
|
||||
bool ROCMBlas::DoBlasGemmStridedBatched(
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
|
||||
int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
|
||||
float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
|
||||
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, float alpha, const DeviceMemory<float>& a, int lda,
|
||||
int64 stride_a, const DeviceMemory<float>& b, int ldb, int64 stride_b,
|
||||
float beta, DeviceMemory<float>* c, int ldc, int64 stride_c,
|
||||
int batch_count) {
|
||||
return DoBlasInternal(wrap::rocblas_sgemm_strided_batched, stream,
|
||||
false, /* pointer_mode_host */
|
||||
@ -2371,10 +2371,10 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
|
||||
stride_c, batch_count);
|
||||
}
|
||||
bool ROCMBlas::DoBlasGemmStridedBatched(
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
|
||||
int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
|
||||
double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
|
||||
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, double alpha, const DeviceMemory<double>& a, int lda,
|
||||
int64 stride_a, const DeviceMemory<double>& b, int ldb, int64 stride_b,
|
||||
double beta, DeviceMemory<double>* c, int ldc, int64 stride_c,
|
||||
int batch_count) {
|
||||
return DoBlasInternal(wrap::rocblas_dgemm_strided_batched, stream,
|
||||
false, /* pointer_mode_host */
|
||||
@ -2384,11 +2384,11 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
|
||||
stride_c, batch_count);
|
||||
}
|
||||
bool ROCMBlas::DoBlasGemmStridedBatched(
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, std::complex<float> alpha,
|
||||
const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
|
||||
const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
|
||||
std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
|
||||
const DeviceMemory<std::complex<float>>& a, int lda, int64 stride_a,
|
||||
const DeviceMemory<std::complex<float>>& b, int ldb, int64 stride_b,
|
||||
std::complex<float> beta, DeviceMemory<std::complex<float>>* c, int ldc,
|
||||
int64 stride_c, int batch_count) {
|
||||
LOG(ERROR) << "rocBLAS does not currently support the "
|
||||
"DoBlasGemmStridedBatched operation "
|
||||
@ -2396,11 +2396,11 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
|
||||
return false;
|
||||
}
|
||||
bool ROCMBlas::DoBlasGemmStridedBatched(
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, std::complex<double> alpha,
|
||||
const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
|
||||
const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
|
||||
std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
|
||||
const DeviceMemory<std::complex<double>>& a, int lda, int64 stride_a,
|
||||
const DeviceMemory<std::complex<double>>& b, int ldb, int64 stride_b,
|
||||
std::complex<double> beta, DeviceMemory<std::complex<double>>* c, int ldc,
|
||||
int64 stride_c, int batch_count) {
|
||||
LOG(ERROR) << "rocBLAS does not currently support the "
|
||||
"DoBlasGemmStridedBatched operation "
|
||||
@ -2418,10 +2418,10 @@ void initialize_rocblas() {
|
||||
PluginRegistry::Instance()
|
||||
->RegisterFactory<PluginRegistry::BlasFactory>(
|
||||
rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS",
|
||||
[](internal::StreamExecutorInterface *parent)
|
||||
-> blas::BlasSupport * {
|
||||
gpu::GpuExecutor *rocm_executor =
|
||||
dynamic_cast<gpu::GpuExecutor *>(parent);
|
||||
[](internal::StreamExecutorInterface* parent)
|
||||
-> blas::BlasSupport* {
|
||||
gpu::GpuExecutor* rocm_executor =
|
||||
dynamic_cast<gpu::GpuExecutor*>(parent);
|
||||
if (rocm_executor == nullptr) {
|
||||
LOG(ERROR)
|
||||
<< "Attempting to initialize an instance of the "
|
||||
@ -2430,7 +2430,7 @@ void initialize_rocblas() {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
gpu::ROCMBlas *blas = new gpu::ROCMBlas(rocm_executor);
|
||||
gpu::ROCMBlas* blas = new gpu::ROCMBlas(rocm_executor);
|
||||
if (!blas->Init()) {
|
||||
// Note: Init() will log a more specific error.
|
||||
delete blas;
|
||||
|
@ -62,7 +62,7 @@ class GpuExecutor;
|
||||
// Thread-safe post-initialization.
|
||||
class ROCMBlas : public blas::BlasSupport {
|
||||
public:
|
||||
explicit ROCMBlas(GpuExecutor *parent);
|
||||
explicit ROCMBlas(GpuExecutor* parent);
|
||||
|
||||
// Allocates a rocBLAS handle.
|
||||
bool Init();
|
||||
@ -98,7 +98,7 @@ class ROCMBlas : public blas::BlasSupport {
|
||||
// Convenience functions that call DoBlasInternalImpl with different values
|
||||
// for err_on_failure.
|
||||
template <typename FuncT, typename... Args>
|
||||
bool DoBlasInternal(FuncT rocblas_func, Stream *stream,
|
||||
bool DoBlasInternal(FuncT rocblas_func, Stream* stream,
|
||||
bool pointer_mode_host, Args... args) {
|
||||
return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
|
||||
/*err_on_failure=*/true, args...);
|
||||
@ -114,14 +114,14 @@ class ROCMBlas : public blas::BlasSupport {
|
||||
// strided flavor
|
||||
template <typename T>
|
||||
port::Status AllocateStridedBuffer(
|
||||
const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
|
||||
&raw_ptrs,
|
||||
const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type*>&
|
||||
raw_ptrs,
|
||||
int batch_count, uint64_t batch_stride,
|
||||
ScratchAllocator *scratch_allocator, Stream *stream,
|
||||
ScratchAllocator* scratch_allocator, Stream* stream,
|
||||
std::unique_ptr<TemporaryDeviceMemory<
|
||||
typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
|
||||
DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
|
||||
*device_memory);
|
||||
typename RocBlasTypeConversionHelper<T>::mapped_type>>* temp_memory,
|
||||
DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>*
|
||||
device_memory);
|
||||
|
||||
// A helper function to implement DoBlasGemmBatched interfaces for generic
|
||||
// types.
|
||||
@ -185,7 +185,7 @@ class ROCMBlas : public blas::BlasSupport {
|
||||
|
||||
// GpuExecutor which instantiated this ROCMBlas.
|
||||
// Immutable post-initialization.
|
||||
GpuExecutor *parent_;
|
||||
GpuExecutor* parent_;
|
||||
|
||||
// rocBLAS library handle on the device.
|
||||
rocblas_handle blas_ GUARDED_BY(mu_);
|
||||
|
@ -27,10 +27,6 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
||||
#include "tensorflow/stream_executor/platform/port.h"
|
||||
|
||||
#if defined(TENSORFLOW_USE_ROCM)
|
||||
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
namespace wrap {
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
@ -83,8 +79,8 @@ namespace wrap {
|
||||
__macro(hipDeviceTotalMem) \
|
||||
__macro(hipDriverGetVersion) \
|
||||
__macro(hipEventCreateWithFlags) \
|
||||
__macro(hipEventElapsedTime) \
|
||||
__macro(hipEventDestroy) \
|
||||
__macro(hipEventElapsedTime) \
|
||||
__macro(hipEventQuery) \
|
||||
__macro(hipEventRecord) \
|
||||
__macro(hipEventSynchronize) \
|
||||
|
@ -48,7 +48,7 @@ namespace wrap {
|
||||
#define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \
|
||||
struct WrapperShim__##__name { \
|
||||
template <typename... Args> \
|
||||
hipfftResult operator()(GpuExecutor *parent, Args... args) { \
|
||||
hipfftResult operator()(GpuExecutor* parent, Args... args) { \
|
||||
gpu::ScopedActivateExecutorContext sac{parent}; \
|
||||
return ::__name(args...); \
|
||||
} \
|
||||
@ -86,21 +86,33 @@ namespace wrap {
|
||||
|
||||
#endif
|
||||
|
||||
#define ROCFFT_ROUTINE_EACH(__macro) \
|
||||
__macro(hipfftDestroy) __macro(hipfftSetStream) __macro(hipfftPlan1d) \
|
||||
__macro(hipfftPlan2d) __macro(hipfftPlan3d) __macro(hipfftPlanMany) \
|
||||
__macro(hipfftCreate) __macro(hipfftSetAutoAllocation) \
|
||||
__macro(hipfftSetWorkArea) __macro(hipfftGetSize1d) \
|
||||
__macro(hipfftMakePlan1d) __macro(hipfftGetSize2d) \
|
||||
__macro(hipfftMakePlan2d) __macro(hipfftGetSize3d) \
|
||||
__macro(hipfftMakePlan3d) __macro(hipfftGetSizeMany) \
|
||||
__macro(hipfftMakePlanMany) \
|
||||
__macro(hipfftExecD2Z) \
|
||||
__macro(hipfftExecZ2D) \
|
||||
__macro(hipfftExecC2C) \
|
||||
__macro(hipfftExecC2R) \
|
||||
__macro(hipfftExecZ2Z) \
|
||||
__macro(hipfftExecR2C)
|
||||
// clang-format off
|
||||
#define ROCFFT_ROUTINE_EACH(__macro) \
|
||||
__macro(hipfftDestroy) \
|
||||
__macro(hipfftSetStream) \
|
||||
__macro(hipfftPlan1d) \
|
||||
__macro(hipfftPlan2d) \
|
||||
__macro(hipfftPlan3d) \
|
||||
__macro(hipfftPlanMany) \
|
||||
__macro(hipfftCreate) \
|
||||
__macro(hipfftSetAutoAllocation) \
|
||||
__macro(hipfftSetWorkArea) \
|
||||
__macro(hipfftGetSize1d) \
|
||||
__macro(hipfftMakePlan1d) \
|
||||
__macro(hipfftGetSize2d) \
|
||||
__macro(hipfftMakePlan2d) \
|
||||
__macro(hipfftGetSize3d) \
|
||||
__macro(hipfftMakePlan3d) \
|
||||
__macro(hipfftGetSizeMany) \
|
||||
__macro(hipfftMakePlanMany) \
|
||||
__macro(hipfftExecD2Z) \
|
||||
__macro(hipfftExecZ2D) \
|
||||
__macro(hipfftExecC2C) \
|
||||
__macro(hipfftExecC2R) \
|
||||
__macro(hipfftExecZ2Z) \
|
||||
__macro(hipfftExecR2C) \
|
||||
|
||||
// clang-format on
|
||||
|
||||
ROCFFT_ROUTINE_EACH(STREAM_EXECUTOR_ROCFFT_WRAP)
|
||||
|
||||
@ -515,7 +527,7 @@ bool ROCMFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT hipfftExec,
|
||||
}
|
||||
|
||||
auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
|
||||
GpuComplex(const_cast<InputT *>(GpuMemory(input))),
|
||||
GpuComplex(const_cast<InputT*>(GpuMemory(input))),
|
||||
GpuComplex(GpuMemoryMutable(output)));
|
||||
|
||||
if (ret != HIPFFT_SUCCESS) {
|
||||
@ -542,7 +554,7 @@ bool ROCMFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
|
||||
}
|
||||
|
||||
auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
|
||||
GpuComplex(const_cast<InputT *>(GpuMemory(input))),
|
||||
GpuComplex(const_cast<InputT*>(GpuMemory(input))),
|
||||
GpuComplex(GpuMemoryMutable(output)),
|
||||
rocm_fft_plan->GetFftDirection());
|
||||
|
||||
@ -556,21 +568,21 @@ bool ROCMFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
|
||||
|
||||
#define STREAM_EXECUTOR_ROCM_DEFINE_FFT(__type, __fft_type1, __fft_type2, \
|
||||
__fft_type3) \
|
||||
bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
|
||||
const DeviceMemory<std::complex<__type>> &input, \
|
||||
DeviceMemory<std::complex<__type>> *output) { \
|
||||
bool ROCMFft::DoFft(Stream* stream, fft::Plan* plan, \
|
||||
const DeviceMemory<std::complex<__type>>& input, \
|
||||
DeviceMemory<std::complex<__type>>* output) { \
|
||||
return DoFftWithDirectionInternal( \
|
||||
stream, plan, wrap::hipfftExec##__fft_type1, input, output); \
|
||||
} \
|
||||
bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
|
||||
const DeviceMemory<__type> &input, \
|
||||
DeviceMemory<std::complex<__type>> *output) { \
|
||||
bool ROCMFft::DoFft(Stream* stream, fft::Plan* plan, \
|
||||
const DeviceMemory<__type>& input, \
|
||||
DeviceMemory<std::complex<__type>>* output) { \
|
||||
return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type2, input, \
|
||||
output); \
|
||||
} \
|
||||
bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
|
||||
const DeviceMemory<std::complex<__type>> &input, \
|
||||
DeviceMemory<__type> *output) { \
|
||||
bool ROCMFft::DoFft(Stream* stream, fft::Plan* plan, \
|
||||
const DeviceMemory<std::complex<__type>>& input, \
|
||||
DeviceMemory<__type>* output) { \
|
||||
return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type3, input, \
|
||||
output); \
|
||||
}
|
||||
@ -590,9 +602,9 @@ void initialize_rocfft() {
|
||||
port::Status status =
|
||||
PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>(
|
||||
rocm::kROCmPlatformId, gpu::kRocFftPlugin, "rocFFT",
|
||||
[](internal::StreamExecutorInterface *parent) -> fft::FftSupport * {
|
||||
gpu::GpuExecutor *rocm_executor =
|
||||
dynamic_cast<gpu::GpuExecutor *>(parent);
|
||||
[](internal::StreamExecutorInterface* parent) -> fft::FftSupport* {
|
||||
gpu::GpuExecutor* rocm_executor =
|
||||
dynamic_cast<gpu::GpuExecutor*>(parent);
|
||||
if (rocm_executor == nullptr) {
|
||||
LOG(ERROR)
|
||||
<< "Attempting to initialize an instance of the rocFFT "
|
||||
|
@ -14,12 +14,11 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "rocm/include/hiprand/hiprand.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_rng.h"
|
||||
|
||||
#include "tensorflow/stream_executor/device_memory.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_rng.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
|
||||
#include "tensorflow/stream_executor/lib/env.h"
|
||||
#include "tensorflow/stream_executor/lib/initialize.h"
|
||||
|
Loading…
Reference in New Issue
Block a user