[XLA:GPU] Add implementation of Cholesky that calls into cuSolver.

PiperOrigin-RevId: 236123818
Peter Hawkins 2019-02-28 08:07:05 -08:00 committed by TensorFlower Gardener
parent fdbaab6f50
commit f38eea2aec
18 changed files with 957 additions and 47 deletions

tensorflow/compiler/xla/service/gpu/BUILD

@ -326,6 +326,7 @@ tf_cuda_library(
cc_library(
name = "gpu_executable",
srcs = [
"cholesky_thunk.cc",
"conditional_thunk.cc",
"convolution_thunk.cc",
"copy_thunk.cc",
@ -345,6 +346,7 @@ cc_library(
"while_thunk.cc",
],
hdrs = [
"cholesky_thunk.h",
"conditional_thunk.h",
"convolution_thunk.h",
"copy_thunk.h",
@ -366,6 +368,7 @@ cc_library(
deps = [
":buffer_allocations",
":cudnn_conv_runner",
":cusolver_context",
":hlo_execution_profiler",
":infeed_manager",
":ir_emission_utils",
@ -404,6 +407,7 @@ cc_library(
"//tensorflow/stream_executor",
"//tensorflow/stream_executor:blas",
"//tensorflow/stream_executor:device_memory",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
@ -442,6 +446,7 @@ cc_library(
":cudnn_conv_runner",
":gpu_executable",
":ir_emission_utils",
":scratch_allocator",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla/service:compiler",
@ -459,6 +464,18 @@ cc_library(
],
)
cc_library(
name = "scratch_allocator",
srcs = ["scratch_allocator.cc"],
hdrs = ["scratch_allocator.h"],
deps = [
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/core:stream_executor_no_cuda",
],
)
cc_library(
name = "cudnn_conv_runner",
srcs = ["cudnn_conv_runner.cc"],
@ -515,6 +532,43 @@ tf_cc_test(
],
)
cc_library(
name = "cusolver_context",
srcs = ["cusolver_context.cc"],
hdrs = ["cusolver_context.h"],
deps = [
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/stream_executor:blas",
"@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cusolver",
],
)
cc_library(
name = "cusolver_rewriter",
srcs = ["cusolver_rewriter.cc"],
hdrs = ["cusolver_rewriter.h"],
deps = [
":cusolver_context",
":ir_emission_utils",
":scratch_allocator",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/stream_executor:blas",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "instruction_fusion",
srcs = ["instruction_fusion.cc"],
@ -748,11 +802,13 @@ cc_library(
srcs = ["nvptx_compiler.cc"],
hdrs = ["nvptx_compiler.h"],
deps = [
":cudnn_batchnorm_rewriter",
":cudnn_conv_algorithm_picker",
":cudnn_conv_pad_for_tensor_cores",
":cudnn_conv_padding_legalization",
":cudnn_conv_rewriter",
":cudnn_fused_conv_rewriter",
":cusolver_rewriter",
":fusion_merger",
":gpu_constants",
":gpu_copy_insertion",
@ -779,7 +835,6 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:buffer_liveness",
"//tensorflow/compiler/xla/service:call_inliner",
"//tensorflow/compiler/xla/service:cholesky_expander",
"//tensorflow/compiler/xla/service:conditional_simplifier",
"//tensorflow/compiler/xla/service:convolution_group_converter",
"//tensorflow/compiler/xla/service:dot_decomposer",
@ -809,7 +864,6 @@ cc_library(
"//tensorflow/compiler/xla/service:while_loop_simplifier",
"//tensorflow/compiler/xla/service:while_loop_trip_count_annotator",
"//tensorflow/compiler/xla/service:zero_sized_hlo_elimination",
"//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter",
"//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:cuda_libdevice_path",

tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc

@ -0,0 +1,119 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"
#include <string>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_memory.h"
namespace xla {
namespace gpu {
CholeskyThunk::CholeskyThunk(const CholeskyOptions& options,
BufferAllocation::Slice a_buffer,
BufferAllocation::Slice workspace_buffer,
BufferAllocation::Slice info_buffer,
PrimitiveType type, int64 batch_size, int64 n,
const HloInstruction* hlo)
: Thunk(Kind::kCholesky, hlo),
uplo_(options.lower() ? se::blas::UpperLower::kLower
: se::blas::UpperLower::kUpper),
a_buffer_(a_buffer),
workspace_buffer_(workspace_buffer),
info_buffer_(info_buffer),
type_(type),
batch_size_(batch_size),
a_batch_stride_(n * n *
ShapeUtil::ByteSizeOfPrimitiveType(
hlo->operand(0)->shape().element_type())),
n_(n) {}
Status CholeskyThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
VLOG(3) << "type=" << PrimitiveType_Name(type_)
<< " uplo=" << se::blas::UpperLowerString(uplo_)
<< " batch_size=" << batch_size_ << " n=" << n_
<< " a=" << a_buffer_.ToString()
<< " workspace=" << workspace_buffer_.ToString()
<< " info=" << info_buffer_.ToString();
CusolverContext* context;
{
tensorflow::mutex_lock lock(mu_);
auto result = contexts_.emplace(stream, CusolverContext());
if (result.second) {
TF_ASSIGN_OR_RETURN(result.first->second,
CusolverContext::Create(stream));
}
context = &result.first->second;
}
char* a_base = static_cast<char*>(
buffer_allocations.GetDeviceAddress(a_buffer_).opaque());
int* info_base = static_cast<int*>(
buffer_allocations.GetDeviceAddress(info_buffer_).opaque());
se::DeviceMemoryBase workspace_data =
buffer_allocations.GetDeviceAddress(workspace_buffer_);
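// cuSolver potrf factors one matrix per call, so issue a separate call for
// each batch element. All calls share the same workspace buffer and write
// their status into consecutive entries of the info buffer.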
for (int64 i = 0; i < batch_size_; ++i) {
se::DeviceMemoryBase a_data =
se::DeviceMemoryBase(a_base + i * a_batch_stride_, a_batch_stride_);
se::DeviceMemory<int> info_data(
se::DeviceMemoryBase(info_base + i, sizeof(int)));
switch (type_) {
case F32: {
TF_RETURN_IF_ERROR(
context->Potrf(uplo_, n_, se::DeviceMemory<float>(a_data), n_,
info_data, se::DeviceMemory<float>(workspace_data)));
break;
}
case F64: {
TF_RETURN_IF_ERROR(context->Potrf(
uplo_, n_, se::DeviceMemory<double>(a_data), n_, info_data,
se::DeviceMemory<double>(workspace_data)));
break;
}
case C64: {
TF_RETURN_IF_ERROR(context->Potrf(
uplo_, n_, se::DeviceMemory<std::complex<float>>(a_data), n_,
info_data, se::DeviceMemory<std::complex<float>>(workspace_data)));
break;
}
case C128: {
TF_RETURN_IF_ERROR(context->Potrf(
uplo_, n_, se::DeviceMemory<std::complex<double>>(a_data), n_,
info_data, se::DeviceMemory<std::complex<double>>(workspace_data)));
break;
}
default:
return InvalidArgument("Invalid type for cholesky %s",
PrimitiveType_Name(type_));
}
}
return Status::OK();
}
} // namespace gpu
} // namespace xla

tensorflow/compiler/xla/service/gpu/cholesky_thunk.h

@ -0,0 +1,77 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CHOLESKY_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CHOLESKY_THUNK_H_
#include "absl/base/thread_annotations.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/cusolver_context.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/blas.h"
namespace xla {
namespace gpu {
// This class stores everything that StreamExecutor needs to launch a Cholesky
// decomposition (LAPACK potrf). It is generated by IrEmitter.
//
// Thread-compatible.
class CholeskyThunk : public Thunk {
public:
static StatusOr<int64> ScratchBufferSize(int64 n);
CholeskyThunk(const CholeskyOptions& options,
BufferAllocation::Slice a_buffer,
BufferAllocation::Slice workspace_buffer,
BufferAllocation::Slice info_buffer,
PrimitiveType type,
int64 batch_size, int64 n, const HloInstruction* hlo);
CholeskyThunk(const CholeskyThunk&) = delete;
CholeskyThunk& operator=(const CholeskyThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
HloExecutionProfiler* profiler) override;
private:
se::blas::UpperLower uplo_;
const BufferAllocation::Slice a_buffer_;
const BufferAllocation::Slice workspace_buffer_;
const BufferAllocation::Slice info_buffer_;
const PrimitiveType type_;
const int64 batch_size_;
const int64 a_batch_stride_;
const int64 n_;
tensorflow::mutex mu_;
absl::flat_hash_map<se::Stream*, CusolverContext> contexts_ GUARDED_BY(mu_);
};
} // namespace gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CHOLESKY_THUNK_H_

tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/scratch_allocator.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logger.h"
@ -37,47 +38,6 @@ using absl::optional;
using se::DeviceMemoryBase;
using se::dnn::AlgorithmDesc;
class ScratchAllocator : public se::ScratchAllocator {
public:
ScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator)
: device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
int64 GetMemoryLimitInBytes(se::Stream* stream) override {
return 1LL << 32; // 4GB. TODO(jlebar): Tune this?
}
int64 TotalAllocatedBytes() { return total_allocated_bytes_; }
StatusOr<se::DeviceMemory<uint8>> AllocateBytes(se::Stream* stream,
int64 byte_size) override;
private:
const int device_ordinal_;
DeviceMemoryAllocator* memory_allocator_;
std::vector<OwningDeviceMemory> allocated_buffers_;
int64 total_allocated_bytes_ = 0;
};
StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
se::Stream* stream, int64 byte_size) {
CHECK_GE(byte_size, 0) << "byte_size must be positive.";
if (byte_size > GetMemoryLimitInBytes(stream)) {
return se::port::Status(
se::port::error::RESOURCE_EXHAUSTED,
absl::StrFormat(
"Allocating %d bytes exceeds the memory limit of %d bytes.",
byte_size, GetMemoryLimitInBytes(stream)));
}
TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer,
memory_allocator_->Allocate(device_ordinal_, byte_size,
/*retry_on_failure=*/false));
total_allocated_bytes_ += byte_size;
se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase();
allocated_buffers_.push_back(std::move(allocated_buffer));
return se::DeviceMemory<uint8>(buffer_addr);
}
std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
se::StreamExecutor* stream_exec) {
std::vector<AlgorithmDesc> algorithms;

tensorflow/compiler/xla/service/gpu/cusolver_context.cc

@ -0,0 +1,159 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cusolver_context.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
namespace gpu {
namespace {
// Type traits to get CUDA complex types from std::complex<T>.
template <typename T>
struct CUDAComplexT {
typedef T type;
};
template <>
struct CUDAComplexT<std::complex<float>> {
typedef cuComplex type;
};
template <>
struct CUDAComplexT<std::complex<double>> {
typedef cuDoubleComplex type;
};
template <typename T>
inline typename CUDAComplexT<T>::type* ToDevicePointer(se::DeviceMemory<T> p) {
return static_cast<typename CUDAComplexT<T>::type*>(p.opaque());
}
cublasFillMode_t CUDABlasUpperLower(se::blas::UpperLower uplo) {
switch (uplo) {
case se::blas::UpperLower::kUpper:
return CUBLAS_FILL_MODE_UPPER;
case se::blas::UpperLower::kLower:
return CUBLAS_FILL_MODE_LOWER;
default:
LOG(FATAL) << "Invalid value of blas::UpperLower.";
}
}
// Converts a cuSolver status to a Status.
Status CusolverStatusToStatus(cusolverStatus_t status) {
switch (status) {
case CUSOLVER_STATUS_SUCCESS:
return Status::OK();
case CUSOLVER_STATUS_NOT_INITIALIZED:
return FailedPrecondition("cuSolver has not been initialized");
case CUSOLVER_STATUS_ALLOC_FAILED:
return ResourceExhausted("cuSolver allocation failed");
case CUSOLVER_STATUS_INVALID_VALUE:
return InvalidArgument("cuSolver invalid value error");
case CUSOLVER_STATUS_ARCH_MISMATCH:
return FailedPrecondition("cuSolver architecture mismatch error");
case CUSOLVER_STATUS_MAPPING_ERROR:
return Unknown("cuSolver mapping error");
case CUSOLVER_STATUS_EXECUTION_FAILED:
return Unknown("cuSolver execution failed");
case CUSOLVER_STATUS_INTERNAL_ERROR:
return Internal("cuSolver internal error");
case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
return Unimplemented("cuSolver matrix type not supported error");
case CUSOLVER_STATUS_NOT_SUPPORTED:
return Unimplemented("cuSolver not supported error");
case CUSOLVER_STATUS_ZERO_PIVOT:
return InvalidArgument("cuSolver zero pivot error");
case CUSOLVER_STATUS_INVALID_LICENSE:
return FailedPrecondition("cuSolver invalid license error");
default:
return Unknown("Unknown cuSolver error");
}
}
} // namespace
StatusOr<CusolverContext> CusolverContext::Create(se::Stream* stream) {
cusolverDnHandle_t handle;
TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnCreate(&handle)));
CusolverContext context(stream, handle);
// StreamExecutor really should just expose the Cuda stream to clients...
const cudaStream_t* cuda_stream =
CHECK_NOTNULL(reinterpret_cast<const cudaStream_t*>(
stream->implementation()->GpuStreamMemberHack()));
TF_RETURN_IF_ERROR(
CusolverStatusToStatus(cusolverDnSetStream(handle, *cuda_stream)));
return std::move(context);
}
CusolverContext::CusolverContext(se::Stream* stream, cusolverDnHandle_t handle)
: stream_(stream), handle_(handle) {}
CusolverContext::CusolverContext(CusolverContext&& other) {
handle_ = other.handle_;
stream_ = other.stream_;
other.handle_ = nullptr;
other.stream_ = nullptr;
}
CusolverContext& CusolverContext::operator=(CusolverContext&& other) {
std::swap(handle_, other.handle_);
std::swap(stream_, other.stream_);
return *this;
}
CusolverContext::~CusolverContext() {
if (handle_) {
Status status = CusolverStatusToStatus(cusolverDnDestroy(handle_));
if (!status.ok()) {
LOG(ERROR) << "cusolverDnDestroy failed: " << status;
}
}
}
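// The macros below stamp out PotrfBufferSize and Potrf once per supported
// element type. The second argument of each (type, prefix) pair is the
// cuSolver type prefix (S, D, C, Z) used to build names such as
// cusolverDnSpotrf.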
#define CALL_LAPACK_TYPES(m) \
m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)
#define DN_SOLVER_FN(method, type_prefix) cusolverDn##type_prefix##method
#define POTRF_BUFFER_SIZE_INSTANCE(T, type_prefix) \
StatusOr<int64> CusolverContext::PotrfBufferSize( \
se::blas::UpperLower uplo, int n, se::DeviceMemory<T> A, int lda) { \
int size = -1; \
TF_RETURN_IF_ERROR(CusolverStatusToStatus(DN_SOLVER_FN( \
potrf_bufferSize, type_prefix)(handle(), CUDABlasUpperLower(uplo), n, \
ToDevicePointer(A), lda, &size))); \
return size; \
}
CALL_LAPACK_TYPES(POTRF_BUFFER_SIZE_INSTANCE);
#define POTRF_INSTANCE(T, type_prefix) \
Status CusolverContext::Potrf( \
se::blas::UpperLower uplo, int n, se::DeviceMemory<T> A, int lda, \
se::DeviceMemory<int> lapack_info, se::DeviceMemory<T> workspace) { \
return CusolverStatusToStatus(DN_SOLVER_FN(potrf, type_prefix)( \
handle(), CUDABlasUpperLower(uplo), n, ToDevicePointer(A), lda, \
ToDevicePointer(workspace), workspace.ElementCount(), \
ToDevicePointer(lapack_info))); \
}
CALL_LAPACK_TYPES(POTRF_INSTANCE);
} // namespace gpu
} // namespace xla

tensorflow/compiler/xla/service/gpu/cusolver_context.h

@ -0,0 +1,88 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_CONTEXT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_CONTEXT_H_
#include <complex>
#include "cuda/include/cublas_v2.h"
#include "cuda/include/cusolverDn.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/blas.h"
namespace xla {
namespace gpu {
class CusolverContext {
public:
static StatusOr<CusolverContext> Create(se::Stream* stream);
CusolverContext() = default;
~CusolverContext();
CusolverContext(const CusolverContext&) = delete;
CusolverContext(CusolverContext&&);
CusolverContext& operator=(const CusolverContext&) = delete;
CusolverContext& operator=(CusolverContext&&);
se::Stream* stream() const { return stream_; }
cusolverDnHandle_t handle() const { return handle_; }
// Computes the Cholesky factorization A = L * L^T for a single matrix.
// Returns Status::OK() if the kernel was launched successfully. See:
// http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf
Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<float> dev_A,
int lda, se::DeviceMemory<int> dev_lapack_info,
se::DeviceMemory<float> workspace);
Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<double> dev_A,
int lda, se::DeviceMemory<int> dev_lapack_info,
se::DeviceMemory<double> workspace);
Status Potrf(se::blas::UpperLower uplo, int n,
se::DeviceMemory<std::complex<float>> dev_A, int lda,
se::DeviceMemory<int> dev_lapack_info,
se::DeviceMemory<std::complex<float>> workspace);
Status Potrf(se::blas::UpperLower uplo, int n,
se::DeviceMemory<std::complex<double>> dev_A, int lda,
se::DeviceMemory<int> dev_lapack_info,
se::DeviceMemory<std::complex<double>> workspace);
// Returns the size of the `workspace` required by Potrf, in number of
// elements of size T.
StatusOr<int64> PotrfBufferSize(se::blas::UpperLower uplo, int n,
se::DeviceMemory<float> dev_A, int lda);
StatusOr<int64> PotrfBufferSize(se::blas::UpperLower uplo, int n,
se::DeviceMemory<double> dev_A, int lda);
StatusOr<int64> PotrfBufferSize(se::blas::UpperLower uplo, int n,
se::DeviceMemory<std::complex<float>> dev_A,
int lda);
StatusOr<int64> PotrfBufferSize(se::blas::UpperLower uplo, int n,
se::DeviceMemory<std::complex<double>> dev_A,
int lda);
private:
CusolverContext(se::Stream* stream, cusolverDnHandle_t handle);
se::Stream* stream_ = nullptr;
cusolverDnHandle_t handle_ = nullptr;
};
} // namespace gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_CONTEXT_H_
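As a rough illustration of how the API above fits together (this sketch is not part of the change itself; the function name is made up, and `stream`, `a`, `info`, and `workspace` are assumed to be valid device allocations owned by the caller, e.g. obtained through the ScratchAllocator introduced elsewhere in this commit):

// Illustrative sketch only, assuming the includes and namespaces of
// cusolver_context.h.
Status RunSinglePotrf(se::Stream* stream, int n, se::DeviceMemory<float> a,
                      se::DeviceMemory<int> info,
                      se::DeviceMemory<float> workspace) {
  TF_ASSIGN_OR_RETURN(CusolverContext context,
                      CusolverContext::Create(stream));
  // Ask cuSolver how many float elements of scratch space potrf needs.
  TF_ASSIGN_OR_RETURN(int64 workspace_size,
                      context.PotrfBufferSize(se::blas::UpperLower::kLower, n,
                                              a, /*lda=*/n));
  if (static_cast<int64>(workspace.ElementCount()) < workspace_size) {
    return InvalidArgument("potrf workspace is too small");
  }
  // Factor `a` in place; the LAPACK-style status code is written to `info`.
  return context.Potrf(se::blas::UpperLower::kLower, n, a, /*lda=*/n, info,
                       workspace);
}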

tensorflow/compiler/xla/service/gpu/cusolver_rewriter.cc

@ -0,0 +1,216 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h"
#include <cstdlib>
#include <numeric>
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/scratch_allocator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/blas.h"
namespace xla {
namespace gpu {
namespace {
void SetFortranLayout(Shape* shape) {
LayoutUtil::SetToDefaultLayout(shape);
int n = shape->mutable_layout()->minor_to_major_size();
CHECK_GE(n, 2);
std::swap(shape->mutable_layout()->mutable_minor_to_major()->at(0),
shape->mutable_layout()->mutable_minor_to_major()->at(1));
}
StatusOr<HloInstruction*> CreateCholesky(CusolverContext* context,
ScratchAllocator* allocator,
HloInstruction* operand,
const CholeskyOptions& options,
const OpMetadata& metadata) {
HloComputation* computation = operand->parent();
Shape a_shape = operand->shape();
int ndim = a_shape.dimensions_size();
CHECK_GE(ndim, 2);
int64 n = a_shape.dimensions(ndim - 1);
int64 batch_size = std::accumulate(a_shape.dimensions().begin(),
a_shape.dimensions().end() - 2, int64{1},
[](int64 a, int64 b) { return a * b; });
// Find the workspace size.
se::blas::UpperLower uplo = options.lower() ? se::blas::UpperLower::kLower
: se::blas::UpperLower::kUpper;
int64 workspace_size; // Number of elements of size a_shape.element_type()
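// The cuSolver buffer-size query takes a device pointer to the matrix, so
// allocate a throwaway n x n buffer of the right element type purely to
// perform the query.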
switch (a_shape.element_type()) {
case F32: {
TF_ASSIGN_OR_RETURN(auto a,
allocator->Allocate<float>(context->stream(), n * n));
TF_ASSIGN_OR_RETURN(workspace_size,
context->PotrfBufferSize(uplo, n, a, n));
break;
}
case F64: {
TF_ASSIGN_OR_RETURN(
auto a, allocator->Allocate<double>(context->stream(), n * n));
TF_ASSIGN_OR_RETURN(workspace_size,
context->PotrfBufferSize(uplo, n, a, n));
break;
}
case C64: {
TF_ASSIGN_OR_RETURN(auto a, allocator->Allocate<std::complex<float>>(
context->stream(), n * n));
TF_ASSIGN_OR_RETURN(workspace_size,
context->PotrfBufferSize(uplo, n, a, n));
break;
}
case C128: {
TF_ASSIGN_OR_RETURN(auto a, allocator->Allocate<std::complex<double>>(
context->stream(), n * n));
TF_ASSIGN_OR_RETURN(workspace_size,
context->PotrfBufferSize(uplo, n, a, n));
break;
}
default:
return InvalidArgument("Invalid type for cholesky decomposition: %s",
a_shape.ToString());
}
// TODO(phawkins): Ideally we would relax this constraint. What we actually
// want is that:
// a) the batch dimensions are major, in no particular order.
// b) the two minor dimensions are in Fortran (column-major) order.
SetFortranLayout(&a_shape);
// This call returns a tuple of (cholesky_result, workspace, info) where:
// * cholesky_result is the result of the Cholesky decomposition,
// * workspace is temporary scratch memory used by cuSolver.
// * info contains the Potrf success/failure status.
// Currently we have no meaningful way to report an error, so we simply
// discard the success/failure information. Obviously this is suboptimal.
Shape call_shape = ShapeUtil::MakeTupleShape(
{a_shape,
ShapeUtil::MakeShape(operand->shape().element_type(), {workspace_size}),
ShapeUtil::MakeShape(S32, {batch_size})});
HloInstruction* custom_call =
computation->AddInstruction(HloInstruction::CreateCustomCall(
call_shape, {operand}, kCusolverCholeskyCallTarget, {a_shape}));
custom_call->set_metadata(metadata);
TF_RETURN_IF_ERROR(custom_call->set_backend_config(options));
return custom_call;
}
} // namespace
// Tries to rewrite a single kCholesky instruction into a call to cuSolver.
StatusOr<bool> RunOnInstruction(CusolverContext* context,
ScratchAllocator* allocator,
HloInstruction* instruction) {
if (instruction->opcode() != HloOpcode::kCholesky) {
return false;
}
TF_ASSIGN_OR_RETURN(
HloInstruction * custom_call,
CreateCholesky(context, allocator, instruction->mutable_operand(0),
instruction->cholesky_options(), instruction->metadata()));
VLOG(1) << "Replacing " << instruction->ToString() << " with "
<< custom_call->ToString();
// The CustomCall returns a tuple of (cholesky_result, workspace, info).
// Extract the Cholesky result and replace the original instruction with it.
TF_RETURN_IF_ERROR(instruction->parent()->ReplaceWithNewInstruction(
instruction, HloInstruction::CreateGetTupleElement(instruction->shape(),
custom_call, 0)));
return true;
}
// Rewrites the Cholesky instructions in the given computation into calls to
// cuSolver. Returns true if it made any changes.
StatusOr<bool> CusolverRewriter::RunOnComputation(HloComputation* computation) {
std::vector<HloInstruction*> cusolver_calls;
for (auto* hlo : computation->instructions()) {
if (hlo->opcode() == HloOpcode::kCholesky) {
cusolver_calls.push_back(hlo);
}
}
if (cusolver_calls.empty()) {
return false;
}
// Create a stream for us to do our work on. We don't really need to do any
// work, just allocate memory, but that's the cuSolver API.
se::Stream stream{stream_exec_};
stream.Init();
const auto device_ordinal = stream_exec_->device_ordinal();
// allocator either points to this->allocator_ or, if that's null, to a
// StreamExecutorMemoryAllocator for stream_exec_.
DeviceMemoryAllocator* allocator;
absl::optional<StreamExecutorMemoryAllocator> se_allocator;
if (allocator_ != nullptr) {
allocator = allocator_;
} else {
se_allocator.emplace(stream_exec_->platform(),
absl::Span<se::StreamExecutor* const>({stream_exec_}));
allocator = &*se_allocator;
}
ScratchAllocator scratch_allocator(device_ordinal, allocator);
TF_ASSIGN_OR_RETURN(CusolverContext context,
CusolverContext::Create(&stream));
bool changed = false;
for (HloInstruction* instruction : cusolver_calls) {
TF_ASSIGN_OR_RETURN(
bool result,
RunOnInstruction(&context, &scratch_allocator, instruction));
changed |= result;
}
return changed;
}
CusolverRewriter::CusolverRewriter(se::StreamExecutor* stream_exec,
DeviceMemoryAllocator* allocator)
: stream_exec_(stream_exec), allocator_(allocator) {}
StatusOr<bool> CusolverRewriter::Run(HloModule* module) {
bool changed = false;
for (HloComputation* computation : module->MakeNonfusionComputations()) {
TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
changed |= result;
}
return changed;
}
} // namespace gpu
} // namespace xla

tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h

@ -0,0 +1,48 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/gpu/cusolver_context.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
namespace gpu {
// Rewrites Cholesky calls into CustomCall HLOs that call into cuSolver.
class CusolverRewriter : public HloModulePass {
public:
CusolverRewriter(se::StreamExecutor* stream_exec,
DeviceMemoryAllocator* allocator);
absl::string_view name() const override { return "cusolver-rewriter"; }
StatusOr<bool> Run(HloModule* module) override;
private:
StatusOr<bool> RunOnComputation(HloComputation* computation);
se::StreamExecutor* stream_exec_; // never null
DeviceMemoryAllocator* allocator_; // may be null
};
} // namespace gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_

tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc

@ -142,6 +142,16 @@ bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
target == kCudnnConvBiasActivationForwardCallTarget;
}
const char* const kCusolverCholeskyCallTarget = "__cusolver$cholesky";
bool IsCustomCallToCusolver(const HloInstruction& hlo) {
if (hlo.opcode() != HloOpcode::kCustomCall) {
return false;
}
const auto& target = hlo.custom_call_target();
return target == kCusolverCholeskyCallTarget;
}
bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
return ImplementedAsGemm(hlo) || IsCustomCallToDnnBatchNorm(hlo) ||
IsCustomCallToDnnConvolution(hlo);

tensorflow/compiler/xla/service/gpu/ir_emission_utils.h

@ -131,6 +131,19 @@ extern const char* const kCudnnConvBiasActivationForwardCallTarget;
// kConvolution opcode.
bool IsCustomCallToDnnConvolution(const HloInstruction& hlo);
// Returns true if `hlo` will be implemented as a call to a cuSolver routine.
//
// This returns true if `hlo` is a CustomCall HLO with a call target equal to
// one of the kCusolver... constants, but returns *false* for HLOs with,
// say, a kCholesky opcode.
bool IsCustomCallToCusolver(const HloInstruction& hlo);
// Cholesky decomposition. Takes a (batched) matrix as input, and returns a
// tuple of (result, workspace, info), where result is the result of the
// Cholesky decomposition, workspace is scratch space for cuSolver, and info
// is a success/failure code per batch element.
extern const char* const kCusolverCholeskyCallTarget;
// Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm
// or cuDNN convolution.
bool ImplementedAsLibraryCall(const HloInstruction& hlo);

tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc

@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
@ -480,6 +481,51 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
return Status::OK();
}
if (custom_call->custom_call_target() == kCusolverCholeskyCallTarget) {
TF_ASSIGN_OR_RETURN(CholeskyOptions options,
custom_call->backend_config<CholeskyOptions>());
const Shape& shape = custom_call->operand(0)->shape();
int ndim = shape.dimensions_size();
CHECK_GE(ndim, 2);
int64 n = shape.dimensions(ndim - 1);
const auto& dims = shape.dimensions();
int64 batch_size = std::accumulate(dims.begin(), dims.end() - 2, int64{1},
[](int64 a, int64 b) { return a * b; });
auto operand_buffer = GetAllocationSlice(*custom_call->operand(0));
const auto& assn = ir_emitter_context_->buffer_assignment();
auto a_buffer = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
auto workspace_buffer = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
auto info_buffer = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie();
std::vector<std::unique_ptr<Thunk>> thunks;
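// potrf factors the matrix in place, so unless buffer assignment already
// aliased the operand with the result buffer, copy the operand over first.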
if (operand_buffer != a_buffer) {
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
/*source_address=*/operand_buffer,
/*destination_buffer=*/a_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(shape), custom_call));
}
thunks.push_back(absl::make_unique<CholeskyThunk>(
options, a_buffer, workspace_buffer, info_buffer,
custom_call->operand(0)->shape().element_type(), batch_size, n,
custom_call));
// Elide the sequential thunk if there's no copy.
if (thunks.size() == 1) {
AddThunkToThunkSequence(std::move(thunks[0]));
} else {
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(std::move(thunks), custom_call));
}
return Status::OK();
}
return IrEmitter::HandleCustomCall(custom_call);
}

tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h

@ -320,6 +320,9 @@ class IrEmitterUnnested : public IrEmitter {
// Returns a FftThunk that calls cuFFT to implement `inst`.
std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst);
// Returns a CholeskyThunk that calls cuSolver to implement `inst`.
std::unique_ptr<Thunk> BuildCholeskyThunk(const HloInstruction* inst);
// Returns a TriangularSolveThunk that calls cuBlas to implement `inst`.
std::unique_ptr<Thunk> BuildTriangularSolveThunk(const HloInstruction* inst);

tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc

@ -35,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
@ -47,6 +46,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
@ -188,8 +188,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
&pipeline, hlo_module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);
pipeline.AddPass<CholeskyExpander>();
// TODO(b/64094172): make Call work on GPU instead of inlining.
pipeline.AddPass<CallInliner>();
auto cost_model = [](HloInstruction* conv) {
@ -267,10 +265,11 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
{
// Convert convolutions into CustomCalls to cudnn, then canonicalize them
// (CudnnConvPaddingLegalization).
// (CudnnConvPaddingLegalization). Also rewrite Cholesky ops as cuSolver calls.
HloPassPipeline pipeline("conv_canonicalization");
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
pipeline.AddPass<CusolverRewriter>(stream_exec, device_allocator);
pipeline.AddPass<CudnnConvRewriter>();
pipeline.AddPass<CudnnFusedConvRewriter>();
pipeline.AddPass<CudnnConvPaddingLegalization>();
@ -343,6 +342,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// wouldn't be able to simplify away the new_tuple bits.
pipeline.AddPass<CudnnConvAlgorithmPicker>(stream_exec, device_allocator,
compiler);
// Clean up new_tuple described above.
pipeline.AddPass<TupleSimplifier>();

tensorflow/compiler/xla/service/gpu/scratch_allocator.cc

@ -0,0 +1,43 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/scratch_allocator.h"
namespace xla {
namespace gpu {
StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
se::Stream* stream, int64 byte_size) {
CHECK_GE(byte_size, 0) << "byte_size must be non-negative.";
if (byte_size > GetMemoryLimitInBytes(stream)) {
return se::port::Status(
se::port::error::RESOURCE_EXHAUSTED,
absl::StrFormat(
"Allocating %d bytes exceeds the memory limit of %d bytes.",
byte_size, GetMemoryLimitInBytes(stream)));
}
TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer,
memory_allocator_->Allocate(device_ordinal_, byte_size,
/*retry_on_failure=*/false));
total_allocated_bytes_ += byte_size;
se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase();
allocated_buffers_.push_back(std::move(allocated_buffer));
return se::DeviceMemory<uint8>(buffer_addr);
}
} // namespace gpu
} // namespace xla

tensorflow/compiler/xla/service/gpu/scratch_allocator.h

@ -0,0 +1,61 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_SCRATCH_ALLOCATOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_SCRATCH_ALLOCATOR_H_
#include <vector>
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/owning_device_memory.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
namespace gpu {
class ScratchAllocator : public se::ScratchAllocator {
public:
ScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator)
: device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
int64 GetMemoryLimitInBytes(se::Stream* stream) override {
return 1LL << 32; // 4GB. TODO(jlebar): Tune this?
}
int64 TotalAllocatedBytes() { return total_allocated_bytes_; }
StatusOr<se::DeviceMemory<uint8>> AllocateBytes(se::Stream* stream,
int64 byte_size) override;
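// Typed convenience wrapper: allocates num_elements objects of type T and
// returns the buffer viewed as DeviceMemory<T>.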
template <typename T>
StatusOr<se::DeviceMemory<T>> Allocate(se::Stream* stream,
int64 num_elements) {
TF_ASSIGN_OR_RETURN(se::DeviceMemory<uint8> bytes,
AllocateBytes(stream, num_elements * sizeof(T)));
return se::DeviceMemory<T>(bytes);
}
private:
const int device_ordinal_;
DeviceMemoryAllocator* memory_allocator_;
std::vector<OwningDeviceMemory> allocated_buffers_;
int64 total_allocated_bytes_ = 0;
};
} // namespace gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_SCRATCH_ALLOCATOR_H_

tensorflow/compiler/xla/service/gpu/thunk.cc

@ -20,6 +20,8 @@ namespace gpu {
std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) {
switch (kind) {
case Thunk::kCholesky:
return os << "kCholesky";
case Thunk::kConditional:
return os << "kConditional";
case Thunk::kConvolution:

tensorflow/compiler/xla/service/gpu/thunk.h

@ -42,6 +42,7 @@ class GpuExecutable;
class Thunk {
public:
enum Kind {
kCholesky,
kConditional,
kConvolution,
kCopy,

tensorflow/compiler/xla/util.h

@ -260,6 +260,16 @@ Status Unavailable(const absl::FormatSpec<Args...>& format,
return WithLogBacktrace(
tensorflow::errors::Unavailable(absl::StrFormat(format, args...)));
}
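// The following two helpers mirror the ones above, wrapping
// tensorflow::errors::Unknown and tensorflow::errors::Internal; they are used
// when converting cuSolver status codes to Status in cusolver_context.cc.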
template <typename... Args>
Status Unknown(const absl::FormatSpec<Args...>& format, const Args&... args) {
return WithLogBacktrace(
tensorflow::errors::Unknown(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status Internal(const absl::FormatSpec<Args...>& format, const Args&... args) {
return WithLogBacktrace(
tensorflow::errors::Internal(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status InvalidArgumentStrCat(Args&&... concat) {