Support for complex type GEMM and GEMV
This commit is contained in:
parent
77994fbf12
commit
c329f1c502
File diff suppressed because it is too large
Load Diff
@ -45,6 +45,16 @@ struct RocBlasTypeConversionHelper<Eigen::half> {
|
||||
using mapped_type = rocblas_half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct RocBlasTypeConversionHelper<std::complex<float> > {
|
||||
using mapped_type = rocblas_float_complex;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct RocBlasTypeConversionHelper<std::complex<double> > {
|
||||
using mapped_type = rocblas_double_complex;
|
||||
};
|
||||
|
||||
// Opaque and unique identifier for the rocBLAS plugin.
|
||||
extern const PluginId kRocBlasPlugin;
|
||||
|
||||
@ -110,7 +120,7 @@ class ROCMBlas : public blas::BlasSupport {
|
||||
/*err_on_failure=*/false, args...);
|
||||
}
|
||||
|
||||
// A helper allocation funciton to convert raw pointers memory layout to
|
||||
// A helper allocation function to convert raw pointers memory layout to
|
||||
// strided flavor
|
||||
template <typename T>
|
||||
port::Status AllocateStridedBuffer(
|
||||
@ -121,7 +131,8 @@ class ROCMBlas : public blas::BlasSupport {
|
||||
std::unique_ptr<TemporaryDeviceMemory<
|
||||
typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
|
||||
DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
|
||||
*device_memory);
|
||||
*device_memory, bool copy_data,
|
||||
bool& reallocated);
|
||||
|
||||
// A helper function to implement DoBlasGemmBatched interfaces for generic
|
||||
// types.
|
||||
|
Loading…
Reference in New Issue
Block a user