Support for complex type GEMM and GEMV

This commit is contained in:
Eugene Kuznetsov 2019-12-23 23:07:25 -08:00
parent 77994fbf12
commit c329f1c502
2 changed files with 317 additions and 191 deletions

File diff suppressed because it is too large Load Diff

View File

@ -45,6 +45,16 @@ struct RocBlasTypeConversionHelper<Eigen::half> {
using mapped_type = rocblas_half;
};
template <>
struct RocBlasTypeConversionHelper<std::complex<float> > {
using mapped_type = rocblas_float_complex;
};
template <>
struct RocBlasTypeConversionHelper<std::complex<double> > {
using mapped_type = rocblas_double_complex;
};
// Opaque and unique identifier for the rocBLAS plugin.
extern const PluginId kRocBlasPlugin;
@ -110,7 +120,7 @@ class ROCMBlas : public blas::BlasSupport {
/*err_on_failure=*/false, args...);
}
// A helper allocation funciton to convert raw pointers memory layout to
// A helper allocation function to convert raw pointers memory layout to
// strided flavor
template <typename T>
port::Status AllocateStridedBuffer(
@ -121,7 +131,8 @@ class ROCMBlas : public blas::BlasSupport {
std::unique_ptr<TemporaryDeviceMemory<
typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
*device_memory);
*device_memory, bool copy_data,
bool& reallocated);
// A helper function to implement DoBlasGemmBatched interfaces for generic
// types.