Merge pull request #31491 from ROCmSoftwarePlatform:google_upstream_xla_amdgpu_enable

PiperOrigin-RevId: 268497286

commit 8ee2467401
Author: TensorFlower Gardener
Date:   2019-09-11 11:52:50 -07:00

7 changed files with 61 additions and 34 deletions

View File

@@ -1,5 +1,5 @@
 load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
-load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps")
 load("//tensorflow/core/platform:default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
@@ -38,7 +38,7 @@ cc_library(
         ":xla_cpu_device",
         ":xla_cpu_jit",
         "//tensorflow/compiler/plugin",
-    ] + if_cuda([
+    ] + if_cuda_or_rocm([
         ":xla_gpu_device",
         ":xla_gpu_jit",
     ]),
@@ -61,7 +61,7 @@ cc_library(
 cc_library(
     name = "xla_gpu_jit",
     visibility = ["//visibility:public"],
-    deps = if_cuda([
+    deps = if_cuda_or_rocm([
         ":jit_compilation_passes",
         "//tensorflow/compiler/jit/kernels:xla_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",

View File

@@ -456,6 +456,7 @@ tf_cc_test(
 cc_library(
     name = "gpu_executable",
     srcs = [
+        "cholesky_thunk.cc",
         "collective_permute_thunk.cc",
         "conditional_thunk.cc",
         "convolution_thunk.cc",
@@ -476,10 +477,9 @@ cc_library(
         "triangular_solve_thunk.cc",
         "tuple_thunk.cc",
         "while_thunk.cc",
-    ] + if_cuda_is_configured([
-        "cholesky_thunk.cc",
-    ]),
+    ],
     hdrs = [
+        "cholesky_thunk.h",
         "collective_permute_thunk.h",
         "conditional_thunk.h",
         "convolution_thunk.h",
@@ -500,12 +500,11 @@ cc_library(
         "triangular_solve_thunk.h",
         "tuple_thunk.h",
         "while_thunk.h",
-    ] + if_cuda_is_configured([
-        "cholesky_thunk.h",
-    ]),
+    ],
     deps = [
         ":backend_configs",
         ":buffer_allocations",
+        ":cusolver_context",
         ":cudnn_conv_runner",
         ":gpu_debug_info_manager",
         ":gpu_types",
@@ -559,7 +558,6 @@ cc_library(
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
     ] + if_cuda_is_configured([
-        ":cusolver_context",
         "//tensorflow/stream_executor/cuda:cuda_stream",
         "//tensorflow/core/platform/default/build_config:cublas_plugin",
         "//tensorflow/core/platform/default/build_config:cudnn_plugin",
@@ -731,21 +729,22 @@ tf_cc_test(
 cc_library(
     name = "cusolver_context",
-    srcs = ["cusolver_context.cc"],
+    srcs = if_cuda_is_configured(["cusolver_context.cc"]),
     hdrs = ["cusolver_context.h"],
     deps = [
         # LINT.IfChange
         "@local_config_cuda//cuda:cublas_headers",
         # LINT.ThenChange(//tensorflow/copy.bara.sky:cublas_headers)
-        "@local_config_cuda//cuda:cuda_headers",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
         "//tensorflow/stream_executor:blas",
+    ] + if_cuda_is_configured([
+        "@local_config_cuda//cuda:cuda_headers",
         "//tensorflow/stream_executor/cuda:cusolver_lib",
-    ],
+    ]),
 )

 cc_library(
@@ -1053,7 +1052,6 @@ cc_library(
     deps = [
         ":alias_passthrough_params",
         ":cudnn_batchnorm_rewriter",
-        ":cudnn_conv_algorithm_picker",
         ":cudnn_conv_padding_legalization",
         ":cudnn_conv_rewriter",
         ":fusion_merger",
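
Note on this file's changes: cholesky_thunk.{cc,h} and the ":cusolver_context"
dependency move out of if_cuda_is_configured() because cusolver_context now
builds on every platform (its header supplies a ROCm stub; see the header diff
below). A minimal sketch of that pattern, with hypothetical target names:

load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")

cc_library(
    name = "solver_context",  # hypothetical, mirrors cusolver_context
    # Real sources compile only when CUDA is configured ...
    srcs = if_cuda_is_configured(["solver_context.cc"]),
    # ... but the header is always visible and provides a stub class,
    # so dependents can list this target unconditionally.
    hdrs = ["solver_context.h"],
)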

View File

@@ -22,7 +22,6 @@ limitations under the License.
 #include "absl/algorithm/container.h"
 #include "absl/memory/memory.h"
-#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
 #include "tensorflow/compiler/xla/refcounting_hash_map.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"

View File

@@ -169,7 +169,8 @@ StatusOr<int64> CusolverContext::PotrfBufferSize(PrimitiveType type,
 }

 #define POTRF_INSTANCE(T, type_prefix)                                     \
-  Status CusolverContext::Potrf(                                           \
+  template <>                                                              \
+  Status CusolverContext::Potrf<T>(                                        \
       se::blas::UpperLower uplo, int n, se::DeviceMemory<T> A, int lda,    \
       se::DeviceMemory<int> lapack_info, se::DeviceMemory<T> workspace) {  \
     return CusolverStatusToStatus(DN_SOLVER_FN(potrf, type_prefix)(        \
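
Note: the macro now emits an explicit specialization of the templated Potrf
declared in the header. Assuming DN_SOLVER_FN(potrf, S) pastes together the
cuSolver entry point cusolverDnSpotrf, POTRF_INSTANCE(float, S) expands to
roughly the following sketch (not the literal preprocessor output):

// Approximate expansion; the real call forwards handle, uplo, n, etc.
template <>
Status CusolverContext::Potrf<float>(
    se::blas::UpperLower uplo, int n, se::DeviceMemory<float> A, int lda,
    se::DeviceMemory<int> lapack_info, se::DeviceMemory<float> workspace) {
  return CusolverStatusToStatus(cusolverDnSpotrf(/* handle, args... */));
}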

View File

@@ -18,8 +18,10 @@ limitations under the License.
 #include <complex>

 #include "third_party/gpus/cuda/include/cublas_v2.h"
+#if !TENSORFLOW_USE_ROCM
 #include "third_party/gpus/cuda/include/cusolverDn.h"
+#endif
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
@@ -30,6 +32,8 @@ limitations under the License.
 namespace xla {
 namespace gpu {

+#if !TENSORFLOW_USE_ROCM
+
 class CusolverContext {
  public:
   // stream may be nullptr, in which case the context can only be used for
@@ -43,26 +47,17 @@ class CusolverContext {
   CusolverContext& operator=(const CusolverContext&) = delete;
   CusolverContext& operator=(CusolverContext&&);

   se::Stream* stream() const { return stream_; }
-  cusolverDnHandle_t handle() const { return handle_; }

   // Computes the Cholesky factorization A = L * L^T for a single matrix.
   // Returns Status::OK() if the kernel was launched successfully. See:
   // http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf
-  Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<float> dev_A,
+  template <typename T, typename = std::enable_if_t<
+                            std::is_same<T, float>::value ||
+                            std::is_same<T, double>::value ||
+                            std::is_same<T, std::complex<float>>::value ||
+                            std::is_same<T, std::complex<double>>::value>>
+  Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<T> dev_A,
               int lda, se::DeviceMemory<int> dev_lapack_info,
-              se::DeviceMemory<float> workspace);
-  Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<double> dev_A,
-              int lda, se::DeviceMemory<int> dev_lapack_info,
-              se::DeviceMemory<double> workspace);
-  Status Potrf(se::blas::UpperLower uplo, int n,
-              se::DeviceMemory<std::complex<float>> dev_A, int lda,
-              se::DeviceMemory<int> dev_lapack_info,
-              se::DeviceMemory<std::complex<float>> workspace);
-  Status Potrf(se::blas::UpperLower uplo, int n,
-              se::DeviceMemory<std::complex<double>> dev_A, int lda,
-              se::DeviceMemory<int> dev_lapack_info,
-              se::DeviceMemory<std::complex<double>> workspace);
+              se::DeviceMemory<T> workspace);

   // Returns the size of the `workspace` required by Potrf, in number of
   // elements of `type`.
@@ -72,10 +67,42 @@ class CusolverContext {
  private:
   CusolverContext(se::Stream* stream, cusolverDnHandle_t handle);

+  cusolverDnHandle_t handle() const { return handle_; }
+
   se::Stream* stream_ = nullptr;
   cusolverDnHandle_t handle_ = nullptr;
 };

+#else
+
+typedef void* cusolverDnHandle_t;
+
+// TODO(cheshire): Remove this hack once we have ROCM implementation.
+class CusolverContext {
+ public:
+  static StatusOr<CusolverContext> Create(se::Stream* stream) {
+    LOG(FATAL) << "Unimplemented";
+  }
+
+  template <typename T, typename = std::enable_if_t<
+                            std::is_same<T, float>::value ||
+                            std::is_same<T, double>::value ||
+                            std::is_same<T, std::complex<float>>::value ||
+                            std::is_same<T, std::complex<double>>::value>>
+  Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<T> dev_A,
+               int lda, se::DeviceMemory<int> dev_lapack_info,
+               se::DeviceMemory<T> workspace) {
+    LOG(FATAL) << "Unimplemented";
+  }
+
+  StatusOr<int64> PotrfBufferSize(PrimitiveType type, se::blas::UpperLower uplo,
+                                  int n, int lda) {
+    LOG(FATAL) << "Unimplemented";
+  }
+};
+
+#endif
+
 } // namespace gpu
 } // namespace xla
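
Note: the enable_if_t guard above restricts Potrf to the four element types
cuSolver's potrf supports. A self-contained sketch of the same SFINAE pattern
(illustrative only, not TF code):

#include <complex>
#include <type_traits>

// Accepts only float, double, std::complex<float>, std::complex<double>.
template <typename T, typename = std::enable_if_t<
                          std::is_same<T, float>::value ||
                          std::is_same<T, double>::value ||
                          std::is_same<T, std::complex<float>>::value ||
                          std::is_same<T, std::complex<double>>::value>>
void Factorize(const T&) {}

int main() {
  Factorize(1.0f);                        // OK: float is supported.
  Factorize(std::complex<double>(1, 0));  // OK: complex<double>.
  // Factorize(1);  // Error if uncommented: int is rejected by enable_if.
}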

View File

@@ -16,7 +16,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h"

 #include "absl/strings/str_format.h"
-#include "tensorflow/stream_executor/cuda/cuda_stream.h"
 #include "tensorflow/stream_executor/gpu/gpu_stream.h"

 namespace xla {

View File

@@ -18,3 +18,6 @@ def if_gpu_is_configured(x):
     if cuda_is_configured() or rocm_is_configured():
         return x
     return []
+
+def if_cuda_or_rocm(x):
+    return if_gpu_is_configured(x)
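
Note: if_cuda_or_rocm is a thin alias for if_gpu_is_configured, added so BUILD
files read naturally at call sites. A minimal usage sketch with hypothetical
target names:

load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")

cc_library(
    name = "gpu_kernels",  # hypothetical
    deps = [
        ":common_deps",  # hypothetical, always present
    ] + if_cuda_or_rocm([
        # Added only when CUDA or ROCm is configured.
        ":gpu_only_dep",  # hypothetical
    ]),
)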