Use BlasGemv() when autotune is not set.
PiperOrigin-RevId: 163754092
parent 724884f1ca
commit d03ba54f72
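For context (my example, not part of the commit): when the right-hand operand of a matmul has a single column, the n-loop of GEMM collapses and the product is a plain matrix-vector multiply, which is what BLAS GEMV computes. A minimal standalone sketch with made-up data:

// Sketch of why n == 1 lets the kernel use GEMV: a matmul whose right
// operand has one column is exactly y = A * b.
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const std::size_t m = 2, k = 3;             // C is m x n with n == 1
  std::vector<float> a = {1, 2, 3, 4, 5, 6};  // A: m x k, row-major
  std::vector<float> b = {1, 1, 1};           // B: k x 1, i.e. a vector
  std::vector<float> c(m, 0.0f);
  for (std::size_t i = 0; i < m; ++i)         // the n-loop of GEMM vanishes,
    for (std::size_t j = 0; j < k; ++j)       // leaving a plain GEMV
      c[i] += a[i * k + j] * b[j];
  std::cout << c[0] << " " << c[1] << "\n";   // prints "6 15"
}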
@@ -212,6 +212,11 @@ bool LaunchBlasGemv<Eigen::half>::IsSupported() {
   return false;
 }
 
+template <typename T>
+bool ShouldUseGemv(uint64 n) {
+  return (LaunchBlasGemv<T>::IsSupported() && n == 1);
+}
+
 }  // namespace
 
 bool GetCublasAutotuneComputationType(
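As a reading aid (hypothetical names, not TensorFlow's types): in the matmul kernel, n is the number of output columns after any requested transposes are applied, so n == 1 is exactly the matrix*vector case the new helper tests for. A sketch of that shape bookkeeping under those assumptions:

// MatmulDims and GetDims are illustrative stand-ins; a_shape/b_shape are
// the stored (pre-transpose) 2-D shapes of the operands.
#include <array>
#include <cstdint>
#include <iostream>

struct MatmulDims { uint64_t m, k, n; };

MatmulDims GetDims(std::array<uint64_t, 2> a_shape,
                   std::array<uint64_t, 2> b_shape,
                   bool transpose_a, bool transpose_b) {
  const uint64_t m = transpose_a ? a_shape[1] : a_shape[0];
  const uint64_t k = transpose_a ? a_shape[0] : a_shape[1];
  const uint64_t n = transpose_b ? b_shape[0] : b_shape[1];
  return {m, k, n};
}

int main() {
  // A: 2x3, B: 3x1, no transposes, so C is 2x1, n == 1, and
  // ShouldUseGemv<T>(n) would route this matmul to GEMV.
  const MatmulDims d = GetDims({2, 3}, {3, 1}, false, false);
  std::cout << d.m << " " << d.k << " " << d.n << "\n";  // prints "2 3 1"
}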
@@ -339,7 +344,7 @@ struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
         }
       }
       // Try BlasGemvWithProfiling
-      if (LaunchBlasGemv<T>::IsSupported() && n == 1) {
+      if (ShouldUseGemv<T>(n)) {
        LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
                                   transpose_a ? m : k, transpose_a ? k : m,
                                   a_ptr, b_ptr, &c_ptr, &profile_result);
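The flipped flag and swapped dimensions passed to Compute() are the usual row-major/column-major trick: a row-major m x k matrix occupies the same memory as a column-major k x m matrix holding its transpose, so asking a column-major GEMV for the transpose of that view recovers A * b. A self-contained sketch of the equivalence (naive GEMV, illustrative only):

// Naive column-major GEMV: y = op(M) * x, where M is rows x cols stored
// column-major and op is transpose when `trans` is true.
#include <cstddef>
#include <iostream>
#include <vector>

void Gemv(bool trans, std::size_t rows, std::size_t cols,
          const std::vector<float>& M, const std::vector<float>& x,
          std::vector<float>* y) {
  const std::size_t out = trans ? cols : rows;
  const std::size_t in = trans ? rows : cols;
  y->assign(out, 0.0f);
  for (std::size_t i = 0; i < out; ++i)
    for (std::size_t j = 0; j < in; ++j)
      (*y)[i] += (trans ? M[i * rows + j] : M[j * rows + i]) * x[j];
}

int main() {
  // A is 2x3 row-major; its memory doubles as a 3x2 column-major A^T, so
  // we call Gemv with trans=true and swapped dims, the same trick as
  // Compute(..., !transpose_a, transpose_a ? m : k, transpose_a ? k : m, ...).
  std::vector<float> a = {1, 2, 3, 4, 5, 6};  // [[1,2,3],[4,5,6]] row-major
  std::vector<float> x = {1, 1, 1}, y;
  Gemv(/*trans=*/true, /*rows=*/3, /*cols=*/2, a, x, &y);
  std::cout << y[0] << " " << y[1] << "\n";  // prints "6 15", i.e. A * x
}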
@@ -385,11 +390,17 @@ struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
     // 2) compute type does not support autotune;
     // 3) no algorithm is found;
     // 4) all internal kernels in autotune return invalid results.
+    // For the following cases, we use normal BlasGemv():
+    // 1) We didn't set the use_autotune flag, but LaunchBlasGemv is
+    //    supported and n == 1.
+    // 2) We set the use_autotune flag, it picked BlasGemv(), and it set
+    //    algorithm_config.algorithm() to kDefaultBlasGemv.
     if (!use_autotune || !compute_type_supported || algorithms->empty() ||
         algorithm_config.algorithm() == kNoAlgorithm ||
         algorithm_config.algorithm() == kDefaultBlasGemm ||
         algorithm_config.algorithm() == kDefaultBlasGemv) {
-      if (algorithm_config.algorithm() == kDefaultBlasGemv) {
+      if (algorithm_config.algorithm() == kDefaultBlasGemv ||
+          ShouldUseGemv<T>(n)) {
         // This is a matrix*vector multiply so use GEMV to compute A * b.
         // Here we are multiplying in the natural order, so we have to flip
         // the transposition flag to compensate for the tensor being stored
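Putting the hunk's conditions together, a rough sketch of the decision, where Algo, PickKernel, and the bool flags are hypothetical stand-ins for TensorFlow's real types, not its API:

// Mirrors only the branch structure of the fallback path above.
#include <cstdint>
#include <iostream>

enum class Algo { kNone, kDefaultGemm, kDefaultGemv, kTuned };

template <typename T>
bool ShouldUseGemv(uint64_t n) { return n == 1; }  // support check elided

template <typename T>
const char* PickKernel(bool use_autotune, bool compute_type_supported,
                       bool have_algorithms, Algo algo, uint64_t n) {
  if (!use_autotune || !compute_type_supported || !have_algorithms ||
      algo == Algo::kNone || algo == Algo::kDefaultGemm ||
      algo == Algo::kDefaultGemv) {
    // New in this commit: even with autotune off, an n == 1 product
    // takes the GEMV path.
    if (algo == Algo::kDefaultGemv || ShouldUseGemv<T>(n)) return "BlasGemv";
    return "BlasGemm";
  }
  return "autotuned BlasGemmWithAlgorithm";
}

int main() {
  // autotune off, n == 1: previously plain GEMM, now GEMV.
  std::cout << PickKernel<float>(false, true, false, Algo::kNone, 1) << "\n";
}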