diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 1ebfe235b4d..2b15b12ec24 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -1,5 +1,5 @@
 load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
-load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps")
 load("//tensorflow/core/platform:default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
 
@@ -38,7 +38,7 @@ cc_library(
         ":xla_cpu_device",
         ":xla_cpu_jit",
         "//tensorflow/compiler/plugin",
-    ] + if_cuda([
+    ] + if_cuda_or_rocm([
         ":xla_gpu_device",
         ":xla_gpu_jit",
     ]),
@@ -61,7 +61,7 @@ cc_library(
 cc_library(
     name = "xla_gpu_jit",
     visibility = ["//visibility:public"],
-    deps = if_cuda([
+    deps = if_cuda_or_rocm([
         ":jit_compilation_passes",
         "//tensorflow/compiler/jit/kernels:xla_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 9f709ddb058..b188f04c74e 100755
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -456,6 +456,7 @@ tf_cc_test(
 cc_library(
     name = "gpu_executable",
     srcs = [
+        "cholesky_thunk.cc",
         "collective_permute_thunk.cc",
         "conditional_thunk.cc",
         "convolution_thunk.cc",
@@ -476,10 +477,9 @@ cc_library(
         "triangular_solve_thunk.cc",
         "tuple_thunk.cc",
         "while_thunk.cc",
-    ] + if_cuda_is_configured([
-        "cholesky_thunk.cc",
-    ]),
+    ],
     hdrs = [
+        "cholesky_thunk.h",
         "collective_permute_thunk.h",
         "conditional_thunk.h",
         "convolution_thunk.h",
@@ -500,12 +500,11 @@ cc_library(
         "triangular_solve_thunk.h",
         "tuple_thunk.h",
         "while_thunk.h",
-    ] + if_cuda_is_configured([
-        "cholesky_thunk.h",
-    ]),
+    ],
     deps = [
         ":backend_configs",
         ":buffer_allocations",
+        ":cusolver_context",
         ":cudnn_conv_runner",
         ":gpu_debug_info_manager",
         ":gpu_types",
@@ -559,7 +558,6 @@ cc_library(
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
     ] + if_cuda_is_configured([
-        ":cusolver_context",
         "//tensorflow/stream_executor/cuda:cuda_stream",
         "//tensorflow/core/platform/default/build_config:cublas_plugin",
         "//tensorflow/core/platform/default/build_config:cudnn_plugin",
@@ -731,21 +729,22 @@ tf_cc_test(
 
 cc_library(
     name = "cusolver_context",
-    srcs = ["cusolver_context.cc"],
+    srcs = if_cuda_is_configured(["cusolver_context.cc"]),
     hdrs = ["cusolver_context.h"],
     deps = [
         # LINT.IfChange
         "@local_config_cuda//cuda:cublas_headers",
         # LINT.ThenChange(//tensorflow/copy.bara.sky:cublas_headers)
-        "@local_config_cuda//cuda:cuda_headers",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
         "//tensorflow/stream_executor:blas",
+    ] + if_cuda_is_configured([
+        "@local_config_cuda//cuda:cuda_headers",
         "//tensorflow/stream_executor/cuda:cusolver_lib",
-    ],
+    ]),
 )
 
 cc_library(
@@ -1053,7 +1052,6 @@ cc_library(
     deps = [
         ":alias_passthrough_params",
         ":cudnn_batchnorm_rewriter",
-        ":cudnn_conv_algorithm_picker",
         ":cudnn_conv_padding_legalization",
         ":cudnn_conv_rewriter",
         ":fusion_merger",
diff --git a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc
index 60301b4de64..2fe359861f8 100644
--- a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc
@@ -22,7 +22,6 @@ limitations under the License.
 
 #include "absl/algorithm/container.h"
 #include "absl/memory/memory.h"
-#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
 #include "tensorflow/compiler/xla/refcounting_hash_map.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
diff --git a/tensorflow/compiler/xla/service/gpu/cusolver_context.cc b/tensorflow/compiler/xla/service/gpu/cusolver_context.cc
index 4103a720c98..b18170b00e4 100644
--- a/tensorflow/compiler/xla/service/gpu/cusolver_context.cc
+++ b/tensorflow/compiler/xla/service/gpu/cusolver_context.cc
@@ -169,7 +169,8 @@ StatusOr<int64> CusolverContext::PotrfBufferSize(PrimitiveType type,
 }
 
 #define POTRF_INSTANCE(T, type_prefix)                                       \
-  Status CusolverContext::Potrf(                                             \
+  template <>                                                                \
+  Status CusolverContext::Potrf<T>(                                          \
       se::blas::UpperLower uplo, int n, se::DeviceMemory<T> A, int lda,      \
       se::DeviceMemory<int> lapack_info, se::DeviceMemory<T> workspace) {    \
     return CusolverStatusToStatus(DN_SOLVER_FN(potrf, type_prefix)(          \
diff --git a/tensorflow/compiler/xla/service/gpu/cusolver_context.h b/tensorflow/compiler/xla/service/gpu/cusolver_context.h
index c3d075c47c7..dfe55188b18 100644
--- a/tensorflow/compiler/xla/service/gpu/cusolver_context.h
+++ b/tensorflow/compiler/xla/service/gpu/cusolver_context.h
@@ -18,8 +18,10 @@ limitations under the License.
 
 #include <complex>
 
-#include "third_party/gpus/cuda/include/cublas_v2.h"
+#if !TENSORFLOW_USE_ROCM
 #include "third_party/gpus/cuda/include/cusolverDn.h"
+#endif
+
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
@@ -30,6 +32,8 @@
 namespace xla {
 namespace gpu {
 
+#if !TENSORFLOW_USE_ROCM
+
 class CusolverContext {
  public:
   // stream may be nullptr, in which case the context can only be used for
@@ -43,26 +47,17 @@ class CusolverContext {
   CusolverContext& operator=(const CusolverContext&) = delete;
   CusolverContext& operator=(CusolverContext&&);
 
-  se::Stream* stream() const { return stream_; }
-  cusolverDnHandle_t handle() const { return handle_; }
-
   // Computes the Cholesky factorization A = L * L^T for a single matrix.
   // Returns Status::OK() if the kernel was launched successfully. See:
   // http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf
-  Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<float> dev_A,
+  template <typename T, typename = std::enable_if_t<
+                            std::is_same<T, float>::value ||
+                            std::is_same<T, double>::value ||
+                            std::is_same<T, std::complex<float>>::value ||
+                            std::is_same<T, std::complex<double>>::value>>
+  Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<T> dev_A,
                int lda, se::DeviceMemory<int> dev_lapack_info,
-               se::DeviceMemory<float> workspace);
-  Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<double> dev_A,
-               int lda, se::DeviceMemory<int> dev_lapack_info,
-               se::DeviceMemory<double> workspace);
-  Status Potrf(se::blas::UpperLower uplo, int n,
-               se::DeviceMemory<std::complex<float>> dev_A, int lda,
-               se::DeviceMemory<int> dev_lapack_info,
-               se::DeviceMemory<std::complex<float>> workspace);
-  Status Potrf(se::blas::UpperLower uplo, int n,
-               se::DeviceMemory<std::complex<double>> dev_A, int lda,
-               se::DeviceMemory<int> dev_lapack_info,
-               se::DeviceMemory<std::complex<double>> workspace);
+               se::DeviceMemory<T> workspace);
 
   // Returns the size of the `workspace` required by Potrf, in number of
   // elements of `type`.
@@ -72,10 +67,42 @@ class CusolverContext {
 
  private:
   CusolverContext(se::Stream* stream, cusolverDnHandle_t handle);
 
+  cusolverDnHandle_t handle() const { return handle_; }
+
   se::Stream* stream_ = nullptr;
   cusolverDnHandle_t handle_ = nullptr;
 };
 
+#else
+
+typedef void* cusolverDnHandle_t;
+
+// TODO(cheshire): Remove this hack once we have ROCM implementation.
+class CusolverContext {
+ public:
+  static StatusOr<CusolverContext> Create(se::Stream* stream) {
+    LOG(FATAL) << "Unimplemented";
+  }
+
+  template <typename T, typename = std::enable_if_t<
+                            std::is_same<T, float>::value ||
+                            std::is_same<T, double>::value ||
+                            std::is_same<T, std::complex<float>>::value ||
+                            std::is_same<T, std::complex<double>>::value>>
+  Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<T> dev_A,
+               int lda, se::DeviceMemory<int> dev_lapack_info,
+               se::DeviceMemory<T> workspace) {
+    LOG(FATAL) << "Unimplemented";
+  }
+
+  StatusOr<int64> PotrfBufferSize(PrimitiveType type, se::blas::UpperLower uplo,
+                                  int n, int lda) {
+    LOG(FATAL) << "Unimplemented";
+  }
+};
+
+#endif
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc
index 65673106391..85571804315 100644
--- a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc
@@ -16,7 +16,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h"
 
 #include "absl/strings/str_format.h"
-#include "tensorflow/stream_executor/cuda/cuda_stream.h"
 #include "tensorflow/stream_executor/gpu/gpu_stream.h"
 
 namespace xla {
diff --git a/tensorflow/stream_executor/build_defs.bzl b/tensorflow/stream_executor/build_defs.bzl
index 575ff639e75..3cb24f8468f 100644
--- a/tensorflow/stream_executor/build_defs.bzl
+++ b/tensorflow/stream_executor/build_defs.bzl
@@ -18,3 +18,6 @@ def if_gpu_is_configured(x):
     if cuda_is_configured() or rocm_is_configured():
         return x
     return []
+
+def if_cuda_or_rocm(x):
+    return if_gpu_is_configured(x)
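
Note (not part of the patch): the four concrete Potrf overloads collapse into
one template whose enable_if constraint admits exactly the element types that
cuSOLVER's potrf supports (float, double, and their complex counterparts).
A minimal sketch of a hypothetical call site follows; RunPotrf, dev_a, and the
other buffers are illustrative names, assuming a CusolverContext and device
memory allocated elsewhere:

  // T is deduced as float from the se::DeviceMemory<float> arguments, so
  // existing callers keep compiling unchanged; element types outside the
  // enable_if list are rejected at compile time.
  Status RunPotrf(xla::gpu::CusolverContext& context, int n, int lda,
                  se::DeviceMemory<float> dev_a,
                  se::DeviceMemory<int> dev_lapack_info,
                  se::DeviceMemory<float> workspace) {
    return context.Potrf(se::blas::UpperLower::kLower, n, dev_a, lda,
                         dev_lapack_info, workspace);
  }

On a ROCm build the same call compiles against the stub class and aborts with
LOG(FATAL) at run time; together with srcs = if_cuda_is_configured(...) on
cusolver_context and the new if_cuda_or_rocm() macro (a thin alias for
if_gpu_is_configured()), this lets gpu_executable and the JIT targets link
without CUDA until a rocSOLVER-backed implementation lands.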