Merge pull request #31491 from ROCmSoftwarePlatform:google_upstream_xla_amdgpu_enable
PiperOrigin-RevId: 268497286
This commit is contained in:
commit
8ee2467401
@ -1,5 +1,5 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps")
|
||||
load("//tensorflow/core/platform:default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
|
||||
|
||||
@ -38,7 +38,7 @@ cc_library(
|
||||
":xla_cpu_device",
|
||||
":xla_cpu_jit",
|
||||
"//tensorflow/compiler/plugin",
|
||||
] + if_cuda([
|
||||
] + if_cuda_or_rocm([
|
||||
":xla_gpu_device",
|
||||
":xla_gpu_jit",
|
||||
]),
|
||||
@ -61,7 +61,7 @@ cc_library(
|
||||
cc_library(
|
||||
name = "xla_gpu_jit",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = if_cuda([
|
||||
deps = if_cuda_or_rocm([
|
||||
":jit_compilation_passes",
|
||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
|
||||
@ -456,6 +456,7 @@ tf_cc_test(
|
||||
cc_library(
|
||||
name = "gpu_executable",
|
||||
srcs = [
|
||||
"cholesky_thunk.cc",
|
||||
"collective_permute_thunk.cc",
|
||||
"conditional_thunk.cc",
|
||||
"convolution_thunk.cc",
|
||||
@ -476,10 +477,9 @@ cc_library(
|
||||
"triangular_solve_thunk.cc",
|
||||
"tuple_thunk.cc",
|
||||
"while_thunk.cc",
|
||||
] + if_cuda_is_configured([
|
||||
"cholesky_thunk.cc",
|
||||
]),
|
||||
],
|
||||
hdrs = [
|
||||
"cholesky_thunk.h",
|
||||
"collective_permute_thunk.h",
|
||||
"conditional_thunk.h",
|
||||
"convolution_thunk.h",
|
||||
@ -500,12 +500,11 @@ cc_library(
|
||||
"triangular_solve_thunk.h",
|
||||
"tuple_thunk.h",
|
||||
"while_thunk.h",
|
||||
] + if_cuda_is_configured([
|
||||
"cholesky_thunk.h",
|
||||
]),
|
||||
],
|
||||
deps = [
|
||||
":backend_configs",
|
||||
":buffer_allocations",
|
||||
":cusolver_context",
|
||||
":cudnn_conv_runner",
|
||||
":gpu_debug_info_manager",
|
||||
":gpu_types",
|
||||
@ -559,7 +558,6 @@ cc_library(
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
] + if_cuda_is_configured([
|
||||
":cusolver_context",
|
||||
"//tensorflow/stream_executor/cuda:cuda_stream",
|
||||
"//tensorflow/core/platform/default/build_config:cublas_plugin",
|
||||
"//tensorflow/core/platform/default/build_config:cudnn_plugin",
|
||||
@ -731,21 +729,22 @@ tf_cc_test(
|
||||
|
||||
cc_library(
|
||||
name = "cusolver_context",
|
||||
srcs = ["cusolver_context.cc"],
|
||||
srcs = if_cuda_is_configured(["cusolver_context.cc"]),
|
||||
hdrs = ["cusolver_context.h"],
|
||||
deps = [
|
||||
# LINT.IfChange
|
||||
"@local_config_cuda//cuda:cublas_headers",
|
||||
# LINT.ThenChange(//tensorflow/copy.bara.sky:cublas_headers)
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/stream_executor:blas",
|
||||
] + if_cuda_is_configured([
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"//tensorflow/stream_executor/cuda:cusolver_lib",
|
||||
],
|
||||
]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -1053,7 +1052,6 @@ cc_library(
|
||||
deps = [
|
||||
":alias_passthrough_params",
|
||||
":cudnn_batchnorm_rewriter",
|
||||
":cudnn_conv_algorithm_picker",
|
||||
":cudnn_conv_padding_legalization",
|
||||
":cudnn_conv_rewriter",
|
||||
":fusion_merger",
|
||||
|
||||
@ -22,7 +22,6 @@ limitations under the License.
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
|
||||
@ -169,7 +169,8 @@ StatusOr<int64> CusolverContext::PotrfBufferSize(PrimitiveType type,
|
||||
}
|
||||
|
||||
#define POTRF_INSTANCE(T, type_prefix) \
|
||||
Status CusolverContext::Potrf( \
|
||||
template <> \
|
||||
Status CusolverContext::Potrf<T>( \
|
||||
se::blas::UpperLower uplo, int n, se::DeviceMemory<T> A, int lda, \
|
||||
se::DeviceMemory<int> lapack_info, se::DeviceMemory<T> workspace) { \
|
||||
return CusolverStatusToStatus(DN_SOLVER_FN(potrf, type_prefix)( \
|
||||
|
||||
@ -18,8 +18,10 @@ limitations under the License.
|
||||
|
||||
#include <complex>
|
||||
|
||||
#include "third_party/gpus/cuda/include/cublas_v2.h"
|
||||
#if !TENSORFLOW_USE_ROCM
|
||||
#include "third_party/gpus/cuda/include/cusolverDn.h"
|
||||
#endif
|
||||
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
@ -30,6 +32,8 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
#if !TENSORFLOW_USE_ROCM
|
||||
|
||||
class CusolverContext {
|
||||
public:
|
||||
// stream may be nullptr, in which case the context can only be used for
|
||||
@ -43,26 +47,17 @@ class CusolverContext {
|
||||
CusolverContext& operator=(const CusolverContext&) = delete;
|
||||
CusolverContext& operator=(CusolverContext&&);
|
||||
|
||||
se::Stream* stream() const { return stream_; }
|
||||
cusolverDnHandle_t handle() const { return handle_; }
|
||||
|
||||
// Computes the Cholesky factorization A = L * L^T for a single matrix.
|
||||
// Returns Status::OK() if the kernel was launched successfully. See:
|
||||
// http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf
|
||||
Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<float> dev_A,
|
||||
template <typename T, typename = std::enable_if_t<
|
||||
std::is_same<T, float>::value ||
|
||||
std::is_same<T, double>::value ||
|
||||
std::is_same<T, std::complex<float>>::value ||
|
||||
std::is_same<T, std::complex<double>>::value>>
|
||||
Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<T> dev_A,
|
||||
int lda, se::DeviceMemory<int> dev_lapack_info,
|
||||
se::DeviceMemory<float> workspace);
|
||||
Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<double> dev_A,
|
||||
int lda, se::DeviceMemory<int> dev_lapack_info,
|
||||
se::DeviceMemory<double> workspace);
|
||||
Status Potrf(se::blas::UpperLower uplo, int n,
|
||||
se::DeviceMemory<std::complex<float>> dev_A, int lda,
|
||||
se::DeviceMemory<int> dev_lapack_info,
|
||||
se::DeviceMemory<std::complex<float>> workspace);
|
||||
Status Potrf(se::blas::UpperLower uplo, int n,
|
||||
se::DeviceMemory<std::complex<double>> dev_A, int lda,
|
||||
se::DeviceMemory<int> dev_lapack_info,
|
||||
se::DeviceMemory<std::complex<double>> workspace);
|
||||
se::DeviceMemory<T> workspace);
|
||||
|
||||
// Returns the size of the `workspace` required by Potrf, in number of
|
||||
// elements of `type`.
|
||||
@ -72,10 +67,42 @@ class CusolverContext {
|
||||
private:
|
||||
CusolverContext(se::Stream* stream, cusolverDnHandle_t handle);
|
||||
|
||||
cusolverDnHandle_t handle() const { return handle_; }
|
||||
|
||||
se::Stream* stream_ = nullptr;
|
||||
cusolverDnHandle_t handle_ = nullptr;
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
typedef void* cusolverDnHandle_t;
|
||||
|
||||
// TODO(cheshire): Remove this hack once we have ROCM implementation.
|
||||
class CusolverContext {
|
||||
public:
|
||||
static StatusOr<CusolverContext> Create(se::Stream* stream) {
|
||||
LOG(FATAL) << "Unimplemented";
|
||||
}
|
||||
|
||||
template <typename T, typename = std::enable_if_t<
|
||||
std::is_same<T, float>::value ||
|
||||
std::is_same<T, double>::value ||
|
||||
std::is_same<T, std::complex<float>>::value ||
|
||||
std::is_same<T, std::complex<double>>::value>>
|
||||
Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<T> dev_A,
|
||||
int lda, se::DeviceMemory<int> dev_lapack_info,
|
||||
se::DeviceMemory<T> workspace) {
|
||||
LOG(FATAL) << "Unimplemented";
|
||||
}
|
||||
|
||||
StatusOr<int64> PotrfBufferSize(PrimitiveType type, se::blas::UpperLower uplo,
|
||||
int n, int lda) {
|
||||
LOG(FATAL) << "Unimplemented";
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
||||
|
||||
@ -16,7 +16,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h"
|
||||
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -18,3 +18,6 @@ def if_gpu_is_configured(x):
|
||||
if cuda_is_configured() or rocm_is_configured():
|
||||
return x
|
||||
return []
|
||||
|
||||
def if_cuda_or_rocm(x):
|
||||
return if_gpu_is_configured(x)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user