Merge pull request from ROCmSoftwarePlatform:google_upstream_cuda_solvers

PiperOrigin-RevId: 252566729
This commit is contained in:
TensorFlower Gardener 2019-06-11 00:58:06 -07:00
commit 7b924fb2a7

View File

@ -21,13 +21,15 @@ limitations under the License.
// algebra solvers in the cuBlas and cuSolverDN libraries for use in TensorFlow
// kernels.
#ifdef GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include <functional>
#include <vector>
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#endif
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
@ -35,6 +37,7 @@ limitations under the License.
namespace tensorflow {
#if GOOGLE_CUDA
// Type traits to get CUDA complex types from std::complex<T>.
template <typename T>
struct CUDAComplexT {
@ -327,6 +330,7 @@ class CudaSolver {
TF_DISALLOW_COPY_AND_ASSIGN(CudaSolver);
};
#endif // GOOGLE_CUDA
// Helper class to allocate scratch memory and keep track of debug info.
// Mostly a thin wrapper around Tensor & allocate_temp.
@ -416,6 +420,7 @@ class DeviceLapackInfo : public ScratchSpace<int> {
}
};
#if GOOGLE_CUDA
template <typename Scalar>
ScratchSpace<Scalar> CudaSolver::GetScratchSpace(const TensorShape& shape,
const string& debug_info,
@ -438,9 +443,10 @@ inline DeviceLapackInfo CudaSolver::GetDeviceLapackInfo(
scratch_tensor_refs_.emplace_back(new_dev_info.tensor());
return new_dev_info;
}
#endif // GOOGLE_CUDA
} // namespace tensorflow
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_