Merge pull request #29153 from ROCmSoftwarePlatform:google_upstream_cuda_solvers
PiperOrigin-RevId: 252566729
This commit is contained in:
commit
7b924fb2a7
@ -21,13 +21,15 @@ limitations under the License.
|
|||||||
// algebra solvers in the cuBlas and cuSolverDN libraries for use in TensorFlow
|
// algebra solvers in the cuBlas and cuSolverDN libraries for use in TensorFlow
|
||||||
// kernels.
|
// kernels.
|
||||||
|
|
||||||
#ifdef GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
#include "third_party/gpus/cuda/include/cublas_v2.h"
|
#include "third_party/gpus/cuda/include/cublas_v2.h"
|
||||||
#include "third_party/gpus/cuda/include/cusolverDn.h"
|
#include "third_party/gpus/cuda/include/cusolverDn.h"
|
||||||
|
#endif
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
@ -35,6 +37,7 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
// Type traits to get CUDA complex types from std::complex<T>.
|
// Type traits to get CUDA complex types from std::complex<T>.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct CUDAComplexT {
|
struct CUDAComplexT {
|
||||||
@ -327,6 +330,7 @@ class CudaSolver {
|
|||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(CudaSolver);
|
TF_DISALLOW_COPY_AND_ASSIGN(CudaSolver);
|
||||||
};
|
};
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
// Helper class to allocate scratch memory and keep track of debug info.
|
// Helper class to allocate scratch memory and keep track of debug info.
|
||||||
// Mostly a thin wrapper around Tensor & allocate_temp.
|
// Mostly a thin wrapper around Tensor & allocate_temp.
|
||||||
@ -416,6 +420,7 @@ class DeviceLapackInfo : public ScratchSpace<int> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
template <typename Scalar>
|
template <typename Scalar>
|
||||||
ScratchSpace<Scalar> CudaSolver::GetScratchSpace(const TensorShape& shape,
|
ScratchSpace<Scalar> CudaSolver::GetScratchSpace(const TensorShape& shape,
|
||||||
const string& debug_info,
|
const string& debug_info,
|
||||||
@ -438,9 +443,10 @@ inline DeviceLapackInfo CudaSolver::GetDeviceLapackInfo(
|
|||||||
scratch_tensor_refs_.emplace_back(new_dev_info.tensor());
|
scratch_tensor_refs_.emplace_back(new_dev_info.tensor());
|
||||||
return new_dev_info;
|
return new_dev_info;
|
||||||
}
|
}
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_
|
#endif // TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_
|
||||||
|
Loading…
x
Reference in New Issue
Block a user