Merge pull request #29153 from ROCmSoftwarePlatform:google_upstream_cuda_solvers
PiperOrigin-RevId: 252566729
This commit is contained in:
commit
7b924fb2a7
@ -21,13 +21,15 @@ limitations under the License.
|
||||
// algebra solvers in the cuBlas and cuSolverDN libraries for use in TensorFlow
|
||||
// kernels.
|
||||
|
||||
#ifdef GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/gpus/cuda/include/cublas_v2.h"
|
||||
#include "third_party/gpus/cuda/include/cusolverDn.h"
|
||||
#endif
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -35,6 +37,7 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
// Type traits to get CUDA complex types from std::complex<T>.
|
||||
template <typename T>
|
||||
struct CUDAComplexT {
|
||||
@ -327,6 +330,7 @@ class CudaSolver {
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(CudaSolver);
|
||||
};
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
// Helper class to allocate scratch memory and keep track of debug info.
|
||||
// Mostly a thin wrapper around Tensor & allocate_temp.
|
||||
@ -416,6 +420,7 @@ class DeviceLapackInfo : public ScratchSpace<int> {
|
||||
}
|
||||
};
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
template <typename Scalar>
|
||||
ScratchSpace<Scalar> CudaSolver::GetScratchSpace(const TensorShape& shape,
|
||||
const string& debug_info,
|
||||
@ -438,9 +443,10 @@ inline DeviceLapackInfo CudaSolver::GetDeviceLapackInfo(
|
||||
scratch_tensor_refs_.emplace_back(new_dev_info.tensor());
|
||||
return new_dev_info;
|
||||
}
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_
|
||||
|
Loading…
Reference in New Issue
Block a user