cosmetic / formatting changes
This commit is contained in:
parent
e1f2191a31
commit
d75510457b
@ -53,14 +53,14 @@ namespace wrap {
|
|||||||
#ifdef PLATFORM_GOOGLE
|
#ifdef PLATFORM_GOOGLE
|
||||||
#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \
|
#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \
|
||||||
struct WrapperShim__##__name { \
|
struct WrapperShim__##__name { \
|
||||||
static const char *kName; \
|
static const char* kName; \
|
||||||
template <typename... Args> \
|
template <typename... Args> \
|
||||||
rocblas_status operator()(GpuExecutor *parent, Args... args) { \
|
rocblas_status operator()(GpuExecutor* parent, Args... args) { \
|
||||||
gpu::ScopedActivateExecutorContext sac{parent}; \
|
gpu::ScopedActivateExecutorContext sac{parent}; \
|
||||||
return ::__name(args...); \
|
return ::__name(args...); \
|
||||||
} \
|
} \
|
||||||
} __name; \
|
} __name; \
|
||||||
const char *WrapperShim__##__name::kName = #__name;
|
const char* WrapperShim__##__name::kName = #__name;
|
||||||
|
|
||||||
#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
|
#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
|
||||||
STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
|
STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
|
||||||
@ -69,14 +69,14 @@ namespace wrap {
|
|||||||
|
|
||||||
#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \
|
#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \
|
||||||
struct DynLoadShim__##__name { \
|
struct DynLoadShim__##__name { \
|
||||||
static const char *kName; \
|
static const char* kName; \
|
||||||
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
|
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
|
||||||
static void *GetDsoHandle() { \
|
static void* GetDsoHandle() { \
|
||||||
auto s = internal::CachedDsoLoader::GetRocblasDsoHandle(); \
|
auto s = internal::CachedDsoLoader::GetRocblasDsoHandle(); \
|
||||||
return s.ValueOrDie(); \
|
return s.ValueOrDie(); \
|
||||||
} \
|
} \
|
||||||
static FuncPtrT LoadOrDie() { \
|
static FuncPtrT LoadOrDie() { \
|
||||||
void *f; \
|
void* f; \
|
||||||
auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
|
auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
|
||||||
kName, &f); \
|
kName, &f); \
|
||||||
CHECK(s.ok()) << "could not find " << kName \
|
CHECK(s.ok()) << "could not find " << kName \
|
||||||
@ -88,12 +88,12 @@ namespace wrap {
|
|||||||
return f; \
|
return f; \
|
||||||
} \
|
} \
|
||||||
template <typename... Args> \
|
template <typename... Args> \
|
||||||
rocblas_status operator()(GpuExecutor *parent, Args... args) { \
|
rocblas_status operator()(GpuExecutor* parent, Args... args) { \
|
||||||
gpu::ScopedActivateExecutorContext sac{parent}; \
|
gpu::ScopedActivateExecutorContext sac{parent}; \
|
||||||
return DynLoad()(args...); \
|
return DynLoad()(args...); \
|
||||||
} \
|
} \
|
||||||
} __name; \
|
} __name; \
|
||||||
const char *DynLoadShim__##__name::kName = #__name;
|
const char* DynLoadShim__##__name::kName = #__name;
|
||||||
|
|
||||||
#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
|
#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
|
||||||
STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
|
STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
|
||||||
@ -322,7 +322,7 @@ bool ROCMBlas::Init() {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
ROCMBlas::ROCMBlas(gpu::GpuExecutor *parent)
|
ROCMBlas::ROCMBlas(gpu::GpuExecutor* parent)
|
||||||
: parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {}
|
: parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {}
|
||||||
|
|
||||||
ROCMBlas::~ROCMBlas() {
|
ROCMBlas::~ROCMBlas() {
|
||||||
@ -1476,11 +1476,11 @@ bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
bool ROCMBlas::DoBlasGemm(Stream* stream, blas::Transpose transa,
|
||||||
blas::Transpose transb, uint64 m, uint64 n, uint64 k,
|
blas::Transpose transb, uint64 m, uint64 n, uint64 k,
|
||||||
float alpha, const DeviceMemory<Eigen::half> &a,
|
float alpha, const DeviceMemory<Eigen::half>& a,
|
||||||
int lda, const DeviceMemory<Eigen::half> &b, int ldb,
|
int lda, const DeviceMemory<Eigen::half>& b, int ldb,
|
||||||
float beta, DeviceMemory<Eigen::half> *c, int ldc) {
|
float beta, DeviceMemory<Eigen::half>* c, int ldc) {
|
||||||
VLOG(1) << absl::StreamFormat(
|
VLOG(1) << absl::StreamFormat(
|
||||||
"doing rocBLAS SGEMM: at=%d bt=%d m=%u n=%u "
|
"doing rocBLAS SGEMM: at=%d bt=%d m=%u n=%u "
|
||||||
"k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
|
"k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
|
||||||
@ -1514,11 +1514,11 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
|||||||
return DoBlasInternal(
|
return DoBlasInternal(
|
||||||
wrap::rocblas_hgemm, stream, true /* = pointer_mode_host */,
|
wrap::rocblas_hgemm, stream, true /* = pointer_mode_host */,
|
||||||
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
|
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
|
||||||
reinterpret_cast<const rocblas_half *>(&alpha_half),
|
reinterpret_cast<const rocblas_half*>(&alpha_half),
|
||||||
reinterpret_cast<const rocblas_half *>(GpuMemory(a)), lda,
|
reinterpret_cast<const rocblas_half*>(GpuMemory(a)), lda,
|
||||||
reinterpret_cast<const rocblas_half *>(GpuMemory(b)), ldb,
|
reinterpret_cast<const rocblas_half*>(GpuMemory(b)), ldb,
|
||||||
reinterpret_cast<const rocblas_half *>(&beta_half),
|
reinterpret_cast<const rocblas_half*>(&beta_half),
|
||||||
reinterpret_cast<rocblas_half *>(GpuMemoryMutable(c)), ldc);
|
reinterpret_cast<rocblas_half*>(GpuMemoryMutable(c)), ldc);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
||||||
@ -1731,12 +1731,12 @@ bool ROCMBlas::GetBlasGemmAlgorithms(
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool ROCMBlas::DoBlasGemmWithAlgorithm(
|
bool ROCMBlas::DoBlasGemmWithAlgorithm(
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||||
uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,
|
uint64 n, uint64 k, const HostOrDeviceScalar<int>& alpha,
|
||||||
const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b, int ldb,
|
const DeviceMemory<int8>& a, int lda, const DeviceMemory<int8>& b, int ldb,
|
||||||
const HostOrDeviceScalar<int> &beta, DeviceMemory<int32> *c, int ldc,
|
const HostOrDeviceScalar<int>& beta, DeviceMemory<int32>* c, int ldc,
|
||||||
blas::ComputationType computation_type, blas::AlgorithmType algorithm,
|
blas::ComputationType computation_type, blas::AlgorithmType algorithm,
|
||||||
blas::ProfileResult *output_profile_result) {
|
blas::ProfileResult* output_profile_result) {
|
||||||
LOG(ERROR)
|
LOG(ERROR)
|
||||||
<< "rocBLAS does not currently support the GEMMwithAlgorithm operation "
|
<< "rocBLAS does not currently support the GEMMwithAlgorithm operation "
|
||||||
<< "for the \"int8\" dataype";
|
<< "for the \"int8\" dataype";
|
||||||
@ -1744,13 +1744,13 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool ROCMBlas::DoBlasGemmWithAlgorithm(
|
bool ROCMBlas::DoBlasGemmWithAlgorithm(
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||||
uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
|
uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half>& alpha,
|
||||||
const DeviceMemory<Eigen::half> &a, int lda,
|
const DeviceMemory<Eigen::half>& a, int lda,
|
||||||
const DeviceMemory<Eigen::half> &b, int ldb,
|
const DeviceMemory<Eigen::half>& b, int ldb,
|
||||||
const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
|
const HostOrDeviceScalar<Eigen::half>& beta, DeviceMemory<Eigen::half>* c,
|
||||||
int ldc, blas::ComputationType computation_type,
|
int ldc, blas::ComputationType computation_type,
|
||||||
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
|
blas::AlgorithmType algorithm, blas::ProfileResult* output_profile_result) {
|
||||||
LOG(ERROR)
|
LOG(ERROR)
|
||||||
<< "rocBLAS does not currently support the GEMMwithAlgorithm operation "
|
<< "rocBLAS does not currently support the GEMMwithAlgorithm operation "
|
||||||
<< "for the \"half\" dataype";
|
<< "for the \"half\" dataype";
|
||||||
@ -1758,12 +1758,12 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool ROCMBlas::DoBlasGemmWithAlgorithm(
|
bool ROCMBlas::DoBlasGemmWithAlgorithm(
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||||
uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,
|
uint64 n, uint64 k, const HostOrDeviceScalar<float>& alpha,
|
||||||
const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
|
const DeviceMemory<float>& a, int lda, const DeviceMemory<float>& b,
|
||||||
int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
|
int ldb, const HostOrDeviceScalar<float>& beta, DeviceMemory<float>* c,
|
||||||
int ldc, blas::ComputationType computation_type,
|
int ldc, blas::ComputationType computation_type,
|
||||||
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
|
blas::AlgorithmType algorithm, blas::ProfileResult* output_profile_result) {
|
||||||
LOG(ERROR)
|
LOG(ERROR)
|
||||||
<< "rocBLAS does not currently support the GEMMwithAlgorithm operation "
|
<< "rocBLAS does not currently support the GEMMwithAlgorithm operation "
|
||||||
<< "for the \"float\" dataype";
|
<< "for the \"float\" dataype";
|
||||||
@ -1771,12 +1771,12 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool ROCMBlas::DoBlasGemmWithAlgorithm(
|
bool ROCMBlas::DoBlasGemmWithAlgorithm(
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||||
uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,
|
uint64 n, uint64 k, const HostOrDeviceScalar<double>& alpha,
|
||||||
const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
|
const DeviceMemory<double>& a, int lda, const DeviceMemory<double>& b,
|
||||||
int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
|
int ldb, const HostOrDeviceScalar<double>& beta, DeviceMemory<double>* c,
|
||||||
int ldc, blas::ComputationType computation_type,
|
int ldc, blas::ComputationType computation_type,
|
||||||
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
|
blas::AlgorithmType algorithm, blas::ProfileResult* output_profile_result) {
|
||||||
LOG(ERROR)
|
LOG(ERROR)
|
||||||
<< "rocBLAS does not currently support the GEMMwithAlgorithm operation "
|
<< "rocBLAS does not currently support the GEMMwithAlgorithm operation "
|
||||||
<< "for the \"double\" dataype";
|
<< "for the \"double\" dataype";
|
||||||
@ -1815,14 +1815,14 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
port::Status ROCMBlas::AllocateStridedBuffer(
|
port::Status ROCMBlas::AllocateStridedBuffer(
|
||||||
const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
|
const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type*>&
|
||||||
&raw_ptrs,
|
raw_ptrs,
|
||||||
int batch_count, uint64_t batch_stride, ScratchAllocator *scratch_allocator,
|
int batch_count, uint64_t batch_stride, ScratchAllocator* scratch_allocator,
|
||||||
Stream *stream,
|
Stream* stream,
|
||||||
std::unique_ptr<TemporaryDeviceMemory<
|
std::unique_ptr<TemporaryDeviceMemory<
|
||||||
typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
|
typename RocBlasTypeConversionHelper<T>::mapped_type>>* temp_memory,
|
||||||
DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
|
DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>*
|
||||||
*device_memory) {
|
device_memory) {
|
||||||
assert(device_memory != nullptr);
|
assert(device_memory != nullptr);
|
||||||
|
|
||||||
using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type;
|
using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type;
|
||||||
@ -1969,12 +1969,12 @@ port::Status ROCMBlas::DoBlasGemmBatchedInternal(
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool ROCMBlas::DoBlasGemmBatched(
|
bool ROCMBlas::DoBlasGemmBatched(
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||||
uint64 n, uint64 k, float alpha,
|
uint64 n, uint64 k, float alpha,
|
||||||
const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
|
const port::ArraySlice<DeviceMemory<Eigen::half>*>& a, int lda,
|
||||||
const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
|
const port::ArraySlice<DeviceMemory<Eigen::half>*>& b, int ldb, float beta,
|
||||||
const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
|
const port::ArraySlice<DeviceMemory<Eigen::half>*>& c, int ldc,
|
||||||
int batch_count, ScratchAllocator *scratch_allocator) {
|
int batch_count, ScratchAllocator* scratch_allocator) {
|
||||||
const Eigen::half alpha_half(alpha);
|
const Eigen::half alpha_half(alpha);
|
||||||
const Eigen::half beta_half(beta);
|
const Eigen::half beta_half(beta);
|
||||||
|
|
||||||
@ -2299,7 +2299,7 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
|
|||||||
return DoBlasInternal(
|
return DoBlasInternal(
|
||||||
wrap::rocblas_strsm, stream, true /* = pointer_mode_host */,
|
wrap::rocblas_strsm, stream, true /* = pointer_mode_host */,
|
||||||
ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
|
ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
|
||||||
ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<float *>(GpuMemory(a)),
|
ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<float*>(GpuMemory(a)),
|
||||||
lda, GpuMemoryMutable(b), ldb);
|
lda, GpuMemoryMutable(b), ldb);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2311,7 +2311,7 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
|
|||||||
return DoBlasInternal(
|
return DoBlasInternal(
|
||||||
wrap::rocblas_dtrsm, stream, true /* = pointer_mode_host */,
|
wrap::rocblas_dtrsm, stream, true /* = pointer_mode_host */,
|
||||||
ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
|
ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
|
||||||
ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<double *>(GpuMemory(a)),
|
ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<double*>(GpuMemory(a)),
|
||||||
lda, GpuMemoryMutable(b), ldb);
|
lda, GpuMemoryMutable(b), ldb);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2349,19 +2349,19 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
|
|||||||
wrap::rocblas_hgemm_strided_batched, stream,
|
wrap::rocblas_hgemm_strided_batched, stream,
|
||||||
false, /* pointer_mode_host */
|
false, /* pointer_mode_host */
|
||||||
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
|
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
|
||||||
reinterpret_cast<const rocblas_half *>(&alpha_half),
|
reinterpret_cast<const rocblas_half*>(&alpha_half),
|
||||||
reinterpret_cast<const rocblas_half *>(GpuMemory(a)), lda, stride_a,
|
reinterpret_cast<const rocblas_half*>(GpuMemory(a)), lda, stride_a,
|
||||||
reinterpret_cast<const rocblas_half *>(GpuMemory(b)), ldb, stride_b,
|
reinterpret_cast<const rocblas_half*>(GpuMemory(b)), ldb, stride_b,
|
||||||
reinterpret_cast<const rocblas_half *>(&beta_half),
|
reinterpret_cast<const rocblas_half*>(&beta_half),
|
||||||
reinterpret_cast<rocblas_half *>(GpuMemoryMutable(c)), ldc, stride_c,
|
reinterpret_cast<rocblas_half*>(GpuMemoryMutable(c)), ldc, stride_c,
|
||||||
batch_count);
|
batch_count);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ROCMBlas::DoBlasGemmStridedBatched(
|
bool ROCMBlas::DoBlasGemmStridedBatched(
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||||
uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
|
uint64 n, uint64 k, float alpha, const DeviceMemory<float>& a, int lda,
|
||||||
int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
|
int64 stride_a, const DeviceMemory<float>& b, int ldb, int64 stride_b,
|
||||||
float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
|
float beta, DeviceMemory<float>* c, int ldc, int64 stride_c,
|
||||||
int batch_count) {
|
int batch_count) {
|
||||||
return DoBlasInternal(wrap::rocblas_sgemm_strided_batched, stream,
|
return DoBlasInternal(wrap::rocblas_sgemm_strided_batched, stream,
|
||||||
false, /* pointer_mode_host */
|
false, /* pointer_mode_host */
|
||||||
@ -2371,10 +2371,10 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
|
|||||||
stride_c, batch_count);
|
stride_c, batch_count);
|
||||||
}
|
}
|
||||||
bool ROCMBlas::DoBlasGemmStridedBatched(
|
bool ROCMBlas::DoBlasGemmStridedBatched(
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||||
uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
|
uint64 n, uint64 k, double alpha, const DeviceMemory<double>& a, int lda,
|
||||||
int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
|
int64 stride_a, const DeviceMemory<double>& b, int ldb, int64 stride_b,
|
||||||
double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
|
double beta, DeviceMemory<double>* c, int ldc, int64 stride_c,
|
||||||
int batch_count) {
|
int batch_count) {
|
||||||
return DoBlasInternal(wrap::rocblas_dgemm_strided_batched, stream,
|
return DoBlasInternal(wrap::rocblas_dgemm_strided_batched, stream,
|
||||||
false, /* pointer_mode_host */
|
false, /* pointer_mode_host */
|
||||||
@ -2384,11 +2384,11 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
|
|||||||
stride_c, batch_count);
|
stride_c, batch_count);
|
||||||
}
|
}
|
||||||
bool ROCMBlas::DoBlasGemmStridedBatched(
|
bool ROCMBlas::DoBlasGemmStridedBatched(
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||||
uint64 n, uint64 k, std::complex<float> alpha,
|
uint64 n, uint64 k, std::complex<float> alpha,
|
||||||
const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
|
const DeviceMemory<std::complex<float>>& a, int lda, int64 stride_a,
|
||||||
const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
|
const DeviceMemory<std::complex<float>>& b, int ldb, int64 stride_b,
|
||||||
std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
|
std::complex<float> beta, DeviceMemory<std::complex<float>>* c, int ldc,
|
||||||
int64 stride_c, int batch_count) {
|
int64 stride_c, int batch_count) {
|
||||||
LOG(ERROR) << "rocBLAS does not currently support the "
|
LOG(ERROR) << "rocBLAS does not currently support the "
|
||||||
"DoBlasGemmStridedBatched operation "
|
"DoBlasGemmStridedBatched operation "
|
||||||
@ -2396,11 +2396,11 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
bool ROCMBlas::DoBlasGemmStridedBatched(
|
bool ROCMBlas::DoBlasGemmStridedBatched(
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
Stream* stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||||
uint64 n, uint64 k, std::complex<double> alpha,
|
uint64 n, uint64 k, std::complex<double> alpha,
|
||||||
const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
|
const DeviceMemory<std::complex<double>>& a, int lda, int64 stride_a,
|
||||||
const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
|
const DeviceMemory<std::complex<double>>& b, int ldb, int64 stride_b,
|
||||||
std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
|
std::complex<double> beta, DeviceMemory<std::complex<double>>* c, int ldc,
|
||||||
int64 stride_c, int batch_count) {
|
int64 stride_c, int batch_count) {
|
||||||
LOG(ERROR) << "rocBLAS does not currently support the "
|
LOG(ERROR) << "rocBLAS does not currently support the "
|
||||||
"DoBlasGemmStridedBatched operation "
|
"DoBlasGemmStridedBatched operation "
|
||||||
@ -2418,10 +2418,10 @@ void initialize_rocblas() {
|
|||||||
PluginRegistry::Instance()
|
PluginRegistry::Instance()
|
||||||
->RegisterFactory<PluginRegistry::BlasFactory>(
|
->RegisterFactory<PluginRegistry::BlasFactory>(
|
||||||
rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS",
|
rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS",
|
||||||
[](internal::StreamExecutorInterface *parent)
|
[](internal::StreamExecutorInterface* parent)
|
||||||
-> blas::BlasSupport * {
|
-> blas::BlasSupport* {
|
||||||
gpu::GpuExecutor *rocm_executor =
|
gpu::GpuExecutor* rocm_executor =
|
||||||
dynamic_cast<gpu::GpuExecutor *>(parent);
|
dynamic_cast<gpu::GpuExecutor*>(parent);
|
||||||
if (rocm_executor == nullptr) {
|
if (rocm_executor == nullptr) {
|
||||||
LOG(ERROR)
|
LOG(ERROR)
|
||||||
<< "Attempting to initialize an instance of the "
|
<< "Attempting to initialize an instance of the "
|
||||||
@ -2430,7 +2430,7 @@ void initialize_rocblas() {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
gpu::ROCMBlas *blas = new gpu::ROCMBlas(rocm_executor);
|
gpu::ROCMBlas* blas = new gpu::ROCMBlas(rocm_executor);
|
||||||
if (!blas->Init()) {
|
if (!blas->Init()) {
|
||||||
// Note: Init() will log a more specific error.
|
// Note: Init() will log a more specific error.
|
||||||
delete blas;
|
delete blas;
|
||||||
|
@ -62,7 +62,7 @@ class GpuExecutor;
|
|||||||
// Thread-safe post-initialization.
|
// Thread-safe post-initialization.
|
||||||
class ROCMBlas : public blas::BlasSupport {
|
class ROCMBlas : public blas::BlasSupport {
|
||||||
public:
|
public:
|
||||||
explicit ROCMBlas(GpuExecutor *parent);
|
explicit ROCMBlas(GpuExecutor* parent);
|
||||||
|
|
||||||
// Allocates a rocBLAS handle.
|
// Allocates a rocBLAS handle.
|
||||||
bool Init();
|
bool Init();
|
||||||
@ -98,7 +98,7 @@ class ROCMBlas : public blas::BlasSupport {
|
|||||||
// Convenience functions that call DoBlasInternalImpl with different values
|
// Convenience functions that call DoBlasInternalImpl with different values
|
||||||
// for err_on_failure.
|
// for err_on_failure.
|
||||||
template <typename FuncT, typename... Args>
|
template <typename FuncT, typename... Args>
|
||||||
bool DoBlasInternal(FuncT rocblas_func, Stream *stream,
|
bool DoBlasInternal(FuncT rocblas_func, Stream* stream,
|
||||||
bool pointer_mode_host, Args... args) {
|
bool pointer_mode_host, Args... args) {
|
||||||
return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
|
return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
|
||||||
/*err_on_failure=*/true, args...);
|
/*err_on_failure=*/true, args...);
|
||||||
@ -114,14 +114,14 @@ class ROCMBlas : public blas::BlasSupport {
|
|||||||
// strided flavor
|
// strided flavor
|
||||||
template <typename T>
|
template <typename T>
|
||||||
port::Status AllocateStridedBuffer(
|
port::Status AllocateStridedBuffer(
|
||||||
const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
|
const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type*>&
|
||||||
&raw_ptrs,
|
raw_ptrs,
|
||||||
int batch_count, uint64_t batch_stride,
|
int batch_count, uint64_t batch_stride,
|
||||||
ScratchAllocator *scratch_allocator, Stream *stream,
|
ScratchAllocator* scratch_allocator, Stream* stream,
|
||||||
std::unique_ptr<TemporaryDeviceMemory<
|
std::unique_ptr<TemporaryDeviceMemory<
|
||||||
typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
|
typename RocBlasTypeConversionHelper<T>::mapped_type>>* temp_memory,
|
||||||
DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
|
DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>*
|
||||||
*device_memory);
|
device_memory);
|
||||||
|
|
||||||
// A helper function to implement DoBlasGemmBatched interfaces for generic
|
// A helper function to implement DoBlasGemmBatched interfaces for generic
|
||||||
// types.
|
// types.
|
||||||
@ -185,7 +185,7 @@ class ROCMBlas : public blas::BlasSupport {
|
|||||||
|
|
||||||
// GpuExecutor which instantiated this ROCMBlas.
|
// GpuExecutor which instantiated this ROCMBlas.
|
||||||
// Immutable post-initialization.
|
// Immutable post-initialization.
|
||||||
GpuExecutor *parent_;
|
GpuExecutor* parent_;
|
||||||
|
|
||||||
// rocBLAS library handle on the device.
|
// rocBLAS library handle on the device.
|
||||||
rocblas_handle blas_ GUARDED_BY(mu_);
|
rocblas_handle blas_ GUARDED_BY(mu_);
|
||||||
|
@ -27,10 +27,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
||||||
#include "tensorflow/stream_executor/platform/port.h"
|
#include "tensorflow/stream_executor/platform/port.h"
|
||||||
|
|
||||||
#if defined(TENSORFLOW_USE_ROCM)
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace wrap {
|
namespace wrap {
|
||||||
#ifdef PLATFORM_GOOGLE
|
#ifdef PLATFORM_GOOGLE
|
||||||
@ -83,8 +79,8 @@ namespace wrap {
|
|||||||
__macro(hipDeviceTotalMem) \
|
__macro(hipDeviceTotalMem) \
|
||||||
__macro(hipDriverGetVersion) \
|
__macro(hipDriverGetVersion) \
|
||||||
__macro(hipEventCreateWithFlags) \
|
__macro(hipEventCreateWithFlags) \
|
||||||
__macro(hipEventElapsedTime) \
|
|
||||||
__macro(hipEventDestroy) \
|
__macro(hipEventDestroy) \
|
||||||
|
__macro(hipEventElapsedTime) \
|
||||||
__macro(hipEventQuery) \
|
__macro(hipEventQuery) \
|
||||||
__macro(hipEventRecord) \
|
__macro(hipEventRecord) \
|
||||||
__macro(hipEventSynchronize) \
|
__macro(hipEventSynchronize) \
|
||||||
|
@ -48,7 +48,7 @@ namespace wrap {
|
|||||||
#define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \
|
#define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \
|
||||||
struct WrapperShim__##__name { \
|
struct WrapperShim__##__name { \
|
||||||
template <typename... Args> \
|
template <typename... Args> \
|
||||||
hipfftResult operator()(GpuExecutor *parent, Args... args) { \
|
hipfftResult operator()(GpuExecutor* parent, Args... args) { \
|
||||||
gpu::ScopedActivateExecutorContext sac{parent}; \
|
gpu::ScopedActivateExecutorContext sac{parent}; \
|
||||||
return ::__name(args...); \
|
return ::__name(args...); \
|
||||||
} \
|
} \
|
||||||
@ -86,21 +86,33 @@ namespace wrap {
|
|||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define ROCFFT_ROUTINE_EACH(__macro) \
|
// clang-format off
|
||||||
__macro(hipfftDestroy) __macro(hipfftSetStream) __macro(hipfftPlan1d) \
|
#define ROCFFT_ROUTINE_EACH(__macro) \
|
||||||
__macro(hipfftPlan2d) __macro(hipfftPlan3d) __macro(hipfftPlanMany) \
|
__macro(hipfftDestroy) \
|
||||||
__macro(hipfftCreate) __macro(hipfftSetAutoAllocation) \
|
__macro(hipfftSetStream) \
|
||||||
__macro(hipfftSetWorkArea) __macro(hipfftGetSize1d) \
|
__macro(hipfftPlan1d) \
|
||||||
__macro(hipfftMakePlan1d) __macro(hipfftGetSize2d) \
|
__macro(hipfftPlan2d) \
|
||||||
__macro(hipfftMakePlan2d) __macro(hipfftGetSize3d) \
|
__macro(hipfftPlan3d) \
|
||||||
__macro(hipfftMakePlan3d) __macro(hipfftGetSizeMany) \
|
__macro(hipfftPlanMany) \
|
||||||
__macro(hipfftMakePlanMany) \
|
__macro(hipfftCreate) \
|
||||||
__macro(hipfftExecD2Z) \
|
__macro(hipfftSetAutoAllocation) \
|
||||||
__macro(hipfftExecZ2D) \
|
__macro(hipfftSetWorkArea) \
|
||||||
__macro(hipfftExecC2C) \
|
__macro(hipfftGetSize1d) \
|
||||||
__macro(hipfftExecC2R) \
|
__macro(hipfftMakePlan1d) \
|
||||||
__macro(hipfftExecZ2Z) \
|
__macro(hipfftGetSize2d) \
|
||||||
__macro(hipfftExecR2C)
|
__macro(hipfftMakePlan2d) \
|
||||||
|
__macro(hipfftGetSize3d) \
|
||||||
|
__macro(hipfftMakePlan3d) \
|
||||||
|
__macro(hipfftGetSizeMany) \
|
||||||
|
__macro(hipfftMakePlanMany) \
|
||||||
|
__macro(hipfftExecD2Z) \
|
||||||
|
__macro(hipfftExecZ2D) \
|
||||||
|
__macro(hipfftExecC2C) \
|
||||||
|
__macro(hipfftExecC2R) \
|
||||||
|
__macro(hipfftExecZ2Z) \
|
||||||
|
__macro(hipfftExecR2C) \
|
||||||
|
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
ROCFFT_ROUTINE_EACH(STREAM_EXECUTOR_ROCFFT_WRAP)
|
ROCFFT_ROUTINE_EACH(STREAM_EXECUTOR_ROCFFT_WRAP)
|
||||||
|
|
||||||
@ -515,7 +527,7 @@ bool ROCMFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT hipfftExec,
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
|
auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
|
||||||
GpuComplex(const_cast<InputT *>(GpuMemory(input))),
|
GpuComplex(const_cast<InputT*>(GpuMemory(input))),
|
||||||
GpuComplex(GpuMemoryMutable(output)));
|
GpuComplex(GpuMemoryMutable(output)));
|
||||||
|
|
||||||
if (ret != HIPFFT_SUCCESS) {
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
@ -542,7 +554,7 @@ bool ROCMFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
|
auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
|
||||||
GpuComplex(const_cast<InputT *>(GpuMemory(input))),
|
GpuComplex(const_cast<InputT*>(GpuMemory(input))),
|
||||||
GpuComplex(GpuMemoryMutable(output)),
|
GpuComplex(GpuMemoryMutable(output)),
|
||||||
rocm_fft_plan->GetFftDirection());
|
rocm_fft_plan->GetFftDirection());
|
||||||
|
|
||||||
@ -556,21 +568,21 @@ bool ROCMFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
|
|||||||
|
|
||||||
#define STREAM_EXECUTOR_ROCM_DEFINE_FFT(__type, __fft_type1, __fft_type2, \
|
#define STREAM_EXECUTOR_ROCM_DEFINE_FFT(__type, __fft_type1, __fft_type2, \
|
||||||
__fft_type3) \
|
__fft_type3) \
|
||||||
bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
|
bool ROCMFft::DoFft(Stream* stream, fft::Plan* plan, \
|
||||||
const DeviceMemory<std::complex<__type>> &input, \
|
const DeviceMemory<std::complex<__type>>& input, \
|
||||||
DeviceMemory<std::complex<__type>> *output) { \
|
DeviceMemory<std::complex<__type>>* output) { \
|
||||||
return DoFftWithDirectionInternal( \
|
return DoFftWithDirectionInternal( \
|
||||||
stream, plan, wrap::hipfftExec##__fft_type1, input, output); \
|
stream, plan, wrap::hipfftExec##__fft_type1, input, output); \
|
||||||
} \
|
} \
|
||||||
bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
|
bool ROCMFft::DoFft(Stream* stream, fft::Plan* plan, \
|
||||||
const DeviceMemory<__type> &input, \
|
const DeviceMemory<__type>& input, \
|
||||||
DeviceMemory<std::complex<__type>> *output) { \
|
DeviceMemory<std::complex<__type>>* output) { \
|
||||||
return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type2, input, \
|
return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type2, input, \
|
||||||
output); \
|
output); \
|
||||||
} \
|
} \
|
||||||
bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
|
bool ROCMFft::DoFft(Stream* stream, fft::Plan* plan, \
|
||||||
const DeviceMemory<std::complex<__type>> &input, \
|
const DeviceMemory<std::complex<__type>>& input, \
|
||||||
DeviceMemory<__type> *output) { \
|
DeviceMemory<__type>* output) { \
|
||||||
return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type3, input, \
|
return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type3, input, \
|
||||||
output); \
|
output); \
|
||||||
}
|
}
|
||||||
@ -590,9 +602,9 @@ void initialize_rocfft() {
|
|||||||
port::Status status =
|
port::Status status =
|
||||||
PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>(
|
PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>(
|
||||||
rocm::kROCmPlatformId, gpu::kRocFftPlugin, "rocFFT",
|
rocm::kROCmPlatformId, gpu::kRocFftPlugin, "rocFFT",
|
||||||
[](internal::StreamExecutorInterface *parent) -> fft::FftSupport * {
|
[](internal::StreamExecutorInterface* parent) -> fft::FftSupport* {
|
||||||
gpu::GpuExecutor *rocm_executor =
|
gpu::GpuExecutor* rocm_executor =
|
||||||
dynamic_cast<gpu::GpuExecutor *>(parent);
|
dynamic_cast<gpu::GpuExecutor*>(parent);
|
||||||
if (rocm_executor == nullptr) {
|
if (rocm_executor == nullptr) {
|
||||||
LOG(ERROR)
|
LOG(ERROR)
|
||||||
<< "Attempting to initialize an instance of the rocFFT "
|
<< "Attempting to initialize an instance of the rocFFT "
|
||||||
|
@ -14,12 +14,11 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "rocm/include/hiprand/hiprand.h"
|
#include "rocm/include/hiprand/hiprand.h"
|
||||||
#include "tensorflow/stream_executor/gpu/gpu_rng.h"
|
|
||||||
|
|
||||||
#include "tensorflow/stream_executor/device_memory.h"
|
#include "tensorflow/stream_executor/device_memory.h"
|
||||||
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
|
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
|
||||||
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
|
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
|
||||||
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
|
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
|
||||||
|
#include "tensorflow/stream_executor/gpu/gpu_rng.h"
|
||||||
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
|
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
|
||||||
#include "tensorflow/stream_executor/lib/env.h"
|
#include "tensorflow/stream_executor/lib/env.h"
|
||||||
#include "tensorflow/stream_executor/lib/initialize.h"
|
#include "tensorflow/stream_executor/lib/initialize.h"
|
||||||
|
Loading…
Reference in New Issue
Block a user