Merge pull request #29153 from ROCmSoftwarePlatform:google_upstream_cuda_solvers

PiperOrigin-RevId: 252566729
This commit is contained in:
TensorFlower Gardener 2019-06-11 00:58:06 -07:00
commit 7b924fb2a7

View File

@@ -21,13 +21,15 @@ limitations under the License.
// algebra solvers in the cuBlas and cuSolverDN libraries for use in TensorFlow // algebra solvers in the cuBlas and cuSolverDN libraries for use in TensorFlow
// kernels. // kernels.
#ifdef GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include <functional> #include <functional>
#include <vector> #include <vector>
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cublas_v2.h" #include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cusolverDn.h" #include "third_party/gpus/cuda/include/cusolverDn.h"
#endif
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
@@ -35,6 +37,7 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
#if GOOGLE_CUDA
// Type traits to get CUDA complex types from std::complex<T>. // Type traits to get CUDA complex types from std::complex<T>.
template <typename T> template <typename T>
struct CUDAComplexT { struct CUDAComplexT {
@@ -327,6 +330,7 @@ class CudaSolver {
TF_DISALLOW_COPY_AND_ASSIGN(CudaSolver); TF_DISALLOW_COPY_AND_ASSIGN(CudaSolver);
}; };
#endif // GOOGLE_CUDA
// Helper class to allocate scratch memory and keep track of debug info. // Helper class to allocate scratch memory and keep track of debug info.
// Mostly a thin wrapper around Tensor & allocate_temp. // Mostly a thin wrapper around Tensor & allocate_temp.
@@ -416,6 +420,7 @@ class DeviceLapackInfo : public ScratchSpace<int> {
} }
}; };
#if GOOGLE_CUDA
template <typename Scalar> template <typename Scalar>
ScratchSpace<Scalar> CudaSolver::GetScratchSpace(const TensorShape& shape, ScratchSpace<Scalar> CudaSolver::GetScratchSpace(const TensorShape& shape,
const string& debug_info, const string& debug_info,
@@ -438,9 +443,10 @@ inline DeviceLapackInfo CudaSolver::GetDeviceLapackInfo(
scratch_tensor_refs_.emplace_back(new_dev_info.tensor()); scratch_tensor_refs_.emplace_back(new_dev_info.tensor());
return new_dev_info; return new_dev_info;
} }
#endif // GOOGLE_CUDA
} // namespace tensorflow } // namespace tensorflow
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_ #endif // TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_