Support for complex type GEMM and GEMV

This commit is contained in:
Eugene Kuznetsov 2019-12-23 23:07:25 -08:00
parent 77994fbf12
commit c329f1c502
2 changed files with 317 additions and 191 deletions

File diff suppressed because it is too large Load Diff

View File

@ -45,6 +45,16 @@ struct RocBlasTypeConversionHelper<Eigen::half> {
using mapped_type = rocblas_half; using mapped_type = rocblas_half;
}; };
template <>
struct RocBlasTypeConversionHelper<std::complex<float> > {
using mapped_type = rocblas_float_complex;
};
template <>
struct RocBlasTypeConversionHelper<std::complex<double> > {
using mapped_type = rocblas_double_complex;
};
// Opaque and unique identifier for the rocBLAS plugin. // Opaque and unique identifier for the rocBLAS plugin.
extern const PluginId kRocBlasPlugin; extern const PluginId kRocBlasPlugin;
@ -110,7 +120,7 @@ class ROCMBlas : public blas::BlasSupport {
/*err_on_failure=*/false, args...); /*err_on_failure=*/false, args...);
} }
// A helper allocation funciton to convert raw pointers memory layout to // A helper allocation function to convert raw pointers memory layout to
// strided flavor // strided flavor
template <typename T> template <typename T>
port::Status AllocateStridedBuffer( port::Status AllocateStridedBuffer(
@ -121,7 +131,8 @@ class ROCMBlas : public blas::BlasSupport {
std::unique_ptr<TemporaryDeviceMemory< std::unique_ptr<TemporaryDeviceMemory<
typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory, typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type> DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
*device_memory); *device_memory, bool copy_data,
bool& reallocated);
// A helper function to implement DoBlasGemmBatched interfaces for generic // A helper function to implement DoBlasGemmBatched interfaces for generic
// types. // types.