[XLA:GPU] Add implementation of Cholesky that calls into cuSolver.
PiperOrigin-RevId: 236123818
parent fdbaab6f50
commit f38eea2aec
@@ -326,6 +326,7 @@ tf_cuda_library(

cc_library(
    name = "gpu_executable",
    srcs = [
        "cholesky_thunk.cc",
        "conditional_thunk.cc",
        "convolution_thunk.cc",
        "copy_thunk.cc",
@@ -345,6 +346,7 @@ cc_library(
        "while_thunk.cc",
    ],
    hdrs = [
        "cholesky_thunk.h",
        "conditional_thunk.h",
        "convolution_thunk.h",
        "copy_thunk.h",
@@ -366,6 +368,7 @@ cc_library(
    deps = [
        ":buffer_allocations",
        ":cudnn_conv_runner",
        ":cusolver_context",
        ":hlo_execution_profiler",
        ":infeed_manager",
        ":ir_emission_utils",
@@ -404,6 +407,7 @@ cc_library(
        "//tensorflow/stream_executor",
        "//tensorflow/stream_executor:blas",
        "//tensorflow/stream_executor:device_memory",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/memory",
@@ -442,6 +446,7 @@ cc_library(
        ":cudnn_conv_runner",
        ":gpu_executable",
        ":ir_emission_utils",
        ":scratch_allocator",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:protobuf_util",
        "//tensorflow/compiler/xla/service:compiler",
@@ -459,6 +464,18 @@ cc_library(
    ],
)

cc_library(
    name = "scratch_allocator",
    srcs = ["scratch_allocator.cc"],
    hdrs = ["scratch_allocator.h"],
    deps = [
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/service:device_memory_allocator",
        "//tensorflow/core:stream_executor_no_cuda",
    ],
)

cc_library(
    name = "cudnn_conv_runner",
    srcs = ["cudnn_conv_runner.cc"],
@@ -515,6 +532,43 @@ tf_cc_test(
    ],
)

cc_library(
    name = "cusolver_context",
    srcs = ["cusolver_context.cc"],
    hdrs = ["cusolver_context.h"],
    deps = [
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/core:lib",
        "//tensorflow/core:stream_executor_no_cuda",
        "//tensorflow/stream_executor:blas",
        "@local_config_cuda//cuda:cuda_headers",
        "@local_config_cuda//cuda:cusolver",
    ],
)

cc_library(
    name = "cusolver_rewriter",
    srcs = ["cusolver_rewriter.cc"],
    hdrs = ["cusolver_rewriter.h"],
    deps = [
        ":cusolver_context",
        ":ir_emission_utils",
        ":scratch_allocator",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/service:device_memory_allocator",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_pass",
        "//tensorflow/core:lib",
        "//tensorflow/core:stream_executor_no_cuda",
        "//tensorflow/stream_executor:blas",
        "@com_google_absl//absl/types:optional",
    ],
)

cc_library(
    name = "instruction_fusion",
    srcs = ["instruction_fusion.cc"],
@@ -748,11 +802,13 @@ cc_library(
    srcs = ["nvptx_compiler.cc"],
    hdrs = ["nvptx_compiler.h"],
    deps = [
        ":cudnn_batchnorm_rewriter",
        ":cudnn_conv_algorithm_picker",
        ":cudnn_conv_pad_for_tensor_cores",
        ":cudnn_conv_padding_legalization",
        ":cudnn_conv_rewriter",
        ":cudnn_fused_conv_rewriter",
        ":cusolver_rewriter",
        ":fusion_merger",
        ":gpu_constants",
        ":gpu_copy_insertion",
@@ -779,7 +835,6 @@ cc_library(
        "//tensorflow/compiler/xla/service:buffer_assignment",
        "//tensorflow/compiler/xla/service:buffer_liveness",
        "//tensorflow/compiler/xla/service:call_inliner",
        "//tensorflow/compiler/xla/service:cholesky_expander",
        "//tensorflow/compiler/xla/service:conditional_simplifier",
        "//tensorflow/compiler/xla/service:convolution_group_converter",
        "//tensorflow/compiler/xla/service:dot_decomposer",
@@ -809,7 +864,6 @@ cc_library(
        "//tensorflow/compiler/xla/service:while_loop_simplifier",
        "//tensorflow/compiler/xla/service:while_loop_trip_count_annotator",
        "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination",
        "//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter",
        "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
        "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
        "//tensorflow/core:cuda_libdevice_path",
tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc (new file, 119 lines)
@@ -0,0 +1,119 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"

#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_memory.h"

namespace xla {
namespace gpu {

CholeskyThunk::CholeskyThunk(const CholeskyOptions& options,
                             BufferAllocation::Slice a_buffer,
                             BufferAllocation::Slice workspace_buffer,
                             BufferAllocation::Slice info_buffer,
                             PrimitiveType type, int64 batch_size, int64 n,
                             const HloInstruction* hlo)
    : Thunk(Kind::kCholesky, hlo),
      uplo_(options.lower() ? se::blas::UpperLower::kLower
                            : se::blas::UpperLower::kUpper),
      a_buffer_(a_buffer),
      workspace_buffer_(workspace_buffer),
      info_buffer_(info_buffer),
      type_(type),
      batch_size_(batch_size),
      a_batch_stride_(n * n *
                      ShapeUtil::ByteSizeOfPrimitiveType(
                          hlo->operand(0)->shape().element_type())),
      n_(n) {}

Status CholeskyThunk::ExecuteOnStream(
    const BufferAllocations& buffer_allocations, se::Stream* stream,
    HloExecutionProfiler* profiler) {
  VLOG(3) << "type=" << PrimitiveType_Name(type_)
          << " uplo=" << se::blas::UpperLowerString(uplo_)
          << " batch_size=" << batch_size_ << " n=" << n_
          << " a=" << a_buffer_.ToString()
          << " workspace=" << workspace_buffer_.ToString()
          << " info=" << info_buffer_.ToString();

  CusolverContext* context;
  {
    tensorflow::mutex_lock lock(mu_);
    auto result = contexts_.emplace(stream, CusolverContext());
    if (result.second) {
      TF_ASSIGN_OR_RETURN(result.first->second,
                          CusolverContext::Create(stream));
    }
    context = &result.first->second;
  }

  char* a_base = static_cast<char*>(
      buffer_allocations.GetDeviceAddress(a_buffer_).opaque());
  int* info_base = static_cast<int*>(
      buffer_allocations.GetDeviceAddress(info_buffer_).opaque());
  se::DeviceMemoryBase workspace_data =
      buffer_allocations.GetDeviceAddress(workspace_buffer_);
  for (int64 i = 0; i < batch_size_; ++i) {
    se::DeviceMemoryBase a_data =
        se::DeviceMemoryBase(a_base + i * a_batch_stride_, a_batch_stride_);
    se::DeviceMemory<int> info_data(
        se::DeviceMemoryBase(info_base + i, sizeof(int)));
    switch (type_) {
      case F32: {
        TF_RETURN_IF_ERROR(
            context->Potrf(uplo_, n_, se::DeviceMemory<float>(a_data), n_,
                           info_data, se::DeviceMemory<float>(workspace_data)));
        break;
      }
      case F64: {
        TF_RETURN_IF_ERROR(context->Potrf(
            uplo_, n_, se::DeviceMemory<double>(a_data), n_, info_data,
            se::DeviceMemory<double>(workspace_data)));
        break;
      }
      case C64: {
        TF_RETURN_IF_ERROR(context->Potrf(
            uplo_, n_, se::DeviceMemory<std::complex<float>>(a_data), n_,
            info_data, se::DeviceMemory<std::complex<float>>(workspace_data)));
        break;
      }
      case C128: {
        TF_RETURN_IF_ERROR(context->Potrf(
            uplo_, n_, se::DeviceMemory<std::complex<double>>(a_data), n_,
            info_data, se::DeviceMemory<std::complex<double>>(workspace_data)));
        break;
      }
      default:
        return InvalidArgument("Invalid type for cholesky %s",
                               PrimitiveType_Name(type_));
    }
  }
  return Status::OK();
}

}  // namespace gpu
}  // namespace xla
tensorflow/compiler/xla/service/gpu/cholesky_thunk.h (new file, 77 lines)
@@ -0,0 +1,77 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CHOLESKY_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CHOLESKY_THUNK_H_

#include "absl/base/thread_annotations.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/cusolver_context.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/blas.h"

namespace xla {
namespace gpu {

// This class stores everything that StreamExecutor needs to launch a Cholesky
// decomposition (LAPACK potrf). It is generated by IrEmitter.
//
// Thread-compatible.
class CholeskyThunk : public Thunk {
 public:
  static StatusOr<int64> ScratchBufferSize(int64 n);
  CholeskyThunk(const CholeskyOptions& options,
                BufferAllocation::Slice a_buffer,
                BufferAllocation::Slice workspace_buffer,
                BufferAllocation::Slice info_buffer,
                PrimitiveType type,
                int64 batch_size, int64 n, const HloInstruction* hlo);

  CholeskyThunk(const CholeskyThunk&) = delete;
  CholeskyThunk& operator=(const CholeskyThunk&) = delete;

  Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
                         se::Stream* stream,
                         HloExecutionProfiler* profiler) override;

 private:
  se::blas::UpperLower uplo_;

  const BufferAllocation::Slice a_buffer_;
  const BufferAllocation::Slice workspace_buffer_;
  const BufferAllocation::Slice info_buffer_;

  const PrimitiveType type_;
  const int64 batch_size_;
  const int64 a_batch_stride_;
  const int64 n_;

  tensorflow::mutex mu_;
  absl::flat_hash_map<se::Stream*, CusolverContext> contexts_ GUARDED_BY(mu_);
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CHOLESKY_THUNK_H_
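For orientation, a minimal usage sketch (not part of the diff) of the CholeskyThunk API declared above. The names `options`, `a`, `workspace`, `info`, `hlo`, `allocations`, `stream`, and `profiler` are hypothetical stand-ins for values normally produced by buffer assignment and the IR emitter; the batch_size/n values are arbitrary:

  // Hypothetical illustration only.
  CholeskyThunk thunk(options, /*a_buffer=*/a, /*workspace_buffer=*/workspace,
                      /*info_buffer=*/info, /*type=*/F32,
                      /*batch_size=*/4, /*n=*/128, hlo);
  TF_RETURN_IF_ERROR(thunk.ExecuteOnStream(allocations, stream, profiler));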
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/scratch_allocator.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logger.h"
@@ -37,47 +38,6 @@ using absl::optional;
using se::DeviceMemoryBase;
using se::dnn::AlgorithmDesc;

class ScratchAllocator : public se::ScratchAllocator {
 public:
  ScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator)
      : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}

  int64 GetMemoryLimitInBytes(se::Stream* stream) override {
    return 1LL << 32;  // 4GB. TODO(jlebar): Tune this?
  }
  int64 TotalAllocatedBytes() { return total_allocated_bytes_; }

  StatusOr<se::DeviceMemory<uint8>> AllocateBytes(se::Stream* stream,
                                                  int64 byte_size) override;

 private:
  const int device_ordinal_;
  DeviceMemoryAllocator* memory_allocator_;
  std::vector<OwningDeviceMemory> allocated_buffers_;
  int64 total_allocated_bytes_ = 0;
};

StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
    se::Stream* stream, int64 byte_size) {
  CHECK_GE(byte_size, 0) << "byte_size must be positive.";
  if (byte_size > GetMemoryLimitInBytes(stream)) {
    return se::port::Status(
        se::port::error::RESOURCE_EXHAUSTED,
        absl::StrFormat(
            "Allocating %d bytes exceeds the memory limit of %d bytes.",
            byte_size, GetMemoryLimitInBytes(stream)));
  }

  TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer,
                      memory_allocator_->Allocate(device_ordinal_, byte_size,
                                                  /*retry_on_failure=*/false));
  total_allocated_bytes_ += byte_size;

  se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase();
  allocated_buffers_.push_back(std::move(allocated_buffer));
  return se::DeviceMemory<uint8>(buffer_addr);
}

std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
                                         se::StreamExecutor* stream_exec) {
  std::vector<AlgorithmDesc> algorithms;
tensorflow/compiler/xla/service/gpu/cusolver_context.cc (new file, 159 lines)
@@ -0,0 +1,159 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cusolver_context.h"

#include "tensorflow/compiler/xla/util.h"

namespace xla {
namespace gpu {

namespace {

// Type traits to get CUDA complex types from std::complex<T>.
template <typename T>
struct CUDAComplexT {
  typedef T type;
};
template <>
struct CUDAComplexT<std::complex<float>> {
  typedef cuComplex type;
};
template <>
struct CUDAComplexT<std::complex<double>> {
  typedef cuDoubleComplex type;
};

template <typename T>
inline typename CUDAComplexT<T>::type* ToDevicePointer(se::DeviceMemory<T> p) {
  return static_cast<typename CUDAComplexT<T>::type*>(p.opaque());
}

cublasFillMode_t CUDABlasUpperLower(se::blas::UpperLower uplo) {
  switch (uplo) {
    case se::blas::UpperLower::kUpper:
      return CUBLAS_FILL_MODE_UPPER;
    case se::blas::UpperLower::kLower:
      return CUBLAS_FILL_MODE_LOWER;
    default:
      LOG(FATAL) << "Invalid value of blas::UpperLower.";
  }
}

// Converts a cuSolver status to a Status.
Status CusolverStatusToStatus(cusolverStatus_t status) {
  switch (status) {
    case CUSOLVER_STATUS_SUCCESS:
      return Status::OK();
    case CUSOLVER_STATUS_NOT_INITIALIZED:
      return FailedPrecondition("cuSolver has not been initialized");
    case CUSOLVER_STATUS_ALLOC_FAILED:
      return ResourceExhausted("cuSolver allocation failed");
    case CUSOLVER_STATUS_INVALID_VALUE:
      return InvalidArgument("cuSolver invalid value error");
    case CUSOLVER_STATUS_ARCH_MISMATCH:
      return FailedPrecondition("cuSolver architecture mismatch error");
    case CUSOLVER_STATUS_MAPPING_ERROR:
      return Unknown("cuSolver mapping error");
    case CUSOLVER_STATUS_EXECUTION_FAILED:
      return Unknown("cuSolver execution failed");
    case CUSOLVER_STATUS_INTERNAL_ERROR:
      return Internal("cuSolver internal error");
    case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
      return Unimplemented("cuSolver matrix type not supported error");
    case CUSOLVER_STATUS_NOT_SUPPORTED:
      return Unimplemented("cuSolver not supported error");
    case CUSOLVER_STATUS_ZERO_PIVOT:
      return InvalidArgument("cuSolver zero pivot error");
    case CUSOLVER_STATUS_INVALID_LICENSE:
      return FailedPrecondition("cuSolver invalid license error");
    default:
      return Unknown("Unknown cuSolver error");
  }
}

}  // namespace

StatusOr<CusolverContext> CusolverContext::Create(se::Stream* stream) {
  cusolverDnHandle_t handle;
  TF_RETURN_IF_ERROR(CusolverStatusToStatus(cusolverDnCreate(&handle)));
  CusolverContext context(stream, handle);

  // StreamExecutor really should just expose the Cuda stream to clients...
  const cudaStream_t* cuda_stream =
      CHECK_NOTNULL(reinterpret_cast<const cudaStream_t*>(
          stream->implementation()->GpuStreamMemberHack()));
  TF_RETURN_IF_ERROR(
      CusolverStatusToStatus(cusolverDnSetStream(handle, *cuda_stream)));

  return std::move(context);
}

CusolverContext::CusolverContext(se::Stream* stream, cusolverDnHandle_t handle)
    : stream_(stream), handle_(handle) {}

CusolverContext::CusolverContext(CusolverContext&& other) {
  handle_ = other.handle_;
  stream_ = other.stream_;
  other.handle_ = nullptr;
  other.stream_ = nullptr;
}

CusolverContext& CusolverContext::operator=(CusolverContext&& other) {
  std::swap(handle_, other.handle_);
  std::swap(stream_, other.stream_);
  return *this;
}

CusolverContext::~CusolverContext() {
  if (handle_) {
    Status status = CusolverStatusToStatus(cusolverDnDestroy(handle_));
    if (!status.ok()) {
      LOG(ERROR) << "cusolverDnDestroy failed: " << status;
    }
  }
}

#define CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)

#define DN_SOLVER_FN(method, type_prefix) cusolverDn##type_prefix##method

#define POTRF_BUFFER_SIZE_INSTANCE(T, type_prefix)                            \
  StatusOr<int64> CusolverContext::PotrfBufferSize(                           \
      se::blas::UpperLower uplo, int n, se::DeviceMemory<T> A, int lda) {     \
    int size = -1;                                                            \
    TF_RETURN_IF_ERROR(CusolverStatusToStatus(DN_SOLVER_FN(                   \
        potrf_bufferSize, type_prefix)(handle(), CUDABlasUpperLower(uplo), n, \
                                       ToDevicePointer(A), lda, &size)));     \
    return size;                                                              \
  }

CALL_LAPACK_TYPES(POTRF_BUFFER_SIZE_INSTANCE);

#define POTRF_INSTANCE(T, type_prefix)                                    \
  Status CusolverContext::Potrf(                                          \
      se::blas::UpperLower uplo, int n, se::DeviceMemory<T> A, int lda,   \
      se::DeviceMemory<int> lapack_info, se::DeviceMemory<T> workspace) { \
    return CusolverStatusToStatus(DN_SOLVER_FN(potrf, type_prefix)(       \
        handle(), CUDABlasUpperLower(uplo), n, ToDevicePointer(A), lda,   \
        ToDevicePointer(workspace), workspace.ElementCount(),             \
        ToDevicePointer(lapack_info)));                                   \
  }

CALL_LAPACK_TYPES(POTRF_INSTANCE);

}  // namespace gpu
}  // namespace xla
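To make the macro machinery above concrete, this is approximately what CALL_LAPACK_TYPES(POTRF_INSTANCE) produces for the float case (a sketch, not literal preprocessor output; the S prefix selects the cuSolver entry point cusolverDnSpotrf):

  // Approximate expansion of POTRF_INSTANCE(float, S).
  Status CusolverContext::Potrf(se::blas::UpperLower uplo, int n,
                                se::DeviceMemory<float> A, int lda,
                                se::DeviceMemory<int> lapack_info,
                                se::DeviceMemory<float> workspace) {
    return CusolverStatusToStatus(cusolverDnSpotrf(
        handle(), CUDABlasUpperLower(uplo), n, ToDevicePointer(A), lda,
        ToDevicePointer(workspace), workspace.ElementCount(),
        ToDevicePointer(lapack_info)));
  }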
tensorflow/compiler/xla/service/gpu/cusolver_context.h (new file, 88 lines)
@@ -0,0 +1,88 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_CONTEXT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_CONTEXT_H_

#include <complex>

#include "cuda/include/cublas_v2.h"
#include "cuda/include/cusolverDn.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/blas.h"

namespace xla {
namespace gpu {

class CusolverContext {
 public:
  static StatusOr<CusolverContext> Create(se::Stream* stream);
  CusolverContext() = default;
  ~CusolverContext();

  CusolverContext(const CusolverContext&) = delete;
  CusolverContext(CusolverContext&&);
  CusolverContext& operator=(const CusolverContext&) = delete;
  CusolverContext& operator=(CusolverContext&&);

  se::Stream* stream() const { return stream_; }
  cusolverDnHandle_t handle() const { return handle_; }

  // Computes the Cholesky factorization A = L * L^T for a single matrix.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf
  Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<float> dev_A,
               int lda, se::DeviceMemory<int> dev_lapack_info,
               se::DeviceMemory<float> workspace);
  Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<double> dev_A,
               int lda, se::DeviceMemory<int> dev_lapack_info,
               se::DeviceMemory<double> workspace);
  Status Potrf(se::blas::UpperLower uplo, int n,
               se::DeviceMemory<std::complex<float>> dev_A, int lda,
               se::DeviceMemory<int> dev_lapack_info,
               se::DeviceMemory<std::complex<float>> workspace);
  Status Potrf(se::blas::UpperLower uplo, int n,
               se::DeviceMemory<std::complex<double>> dev_A, int lda,
               se::DeviceMemory<int> dev_lapack_info,
               se::DeviceMemory<std::complex<double>> workspace);

  // Returns the size of the `workspace` required by Potrf, in number of
  // elements of size T.
  StatusOr<int64> PotrfBufferSize(se::blas::UpperLower uplo, int n,
                                  se::DeviceMemory<float> dev_A, int lda);
  StatusOr<int64> PotrfBufferSize(se::blas::UpperLower uplo, int n,
                                  se::DeviceMemory<double> dev_A, int lda);
  StatusOr<int64> PotrfBufferSize(se::blas::UpperLower uplo, int n,
                                  se::DeviceMemory<std::complex<float>> dev_A,
                                  int lda);
  StatusOr<int64> PotrfBufferSize(se::blas::UpperLower uplo, int n,
                                  se::DeviceMemory<std::complex<double>> dev_A,
                                  int lda);

 private:
  CusolverContext(se::Stream* stream, cusolverDnHandle_t handle);

  se::Stream* stream_ = nullptr;
  cusolverDnHandle_t handle_ = nullptr;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_CONTEXT_H_
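A hedged end-to-end sketch of the API declared above (`stream`, `a`, `lapack_info`, `workspace`, and `n` are assumed to exist): bind a context to a stream, query the workspace size, then launch potrf:

  // Illustrative only; errors propagate via the TF_* macros.
  TF_ASSIGN_OR_RETURN(CusolverContext ctx, CusolverContext::Create(stream));
  TF_ASSIGN_OR_RETURN(
      int64 workspace_elems,
      ctx.PotrfBufferSize(se::blas::UpperLower::kLower, n, a, n));
  // ... allocate `workspace_elems` elements of workspace and one int of
  // lapack_info, then:
  TF_RETURN_IF_ERROR(ctx.Potrf(se::blas::UpperLower::kLower, n, a, n,
                               lapack_info, workspace));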
tensorflow/compiler/xla/service/gpu/cusolver_rewriter.cc (new file, 216 lines)
@@ -0,0 +1,216 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h"

#include <cstdlib>
#include <numeric>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/scratch_allocator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/blas.h"

namespace xla {
namespace gpu {

namespace {

void SetFortranLayout(Shape* shape) {
  LayoutUtil::SetToDefaultLayout(shape);
  int n = shape->mutable_layout()->minor_to_major_size();
  CHECK_GE(n, 2);
  std::swap(shape->mutable_layout()->mutable_minor_to_major()->at(0),
            shape->mutable_layout()->mutable_minor_to_major()->at(1));
}

StatusOr<HloInstruction*> CreateCholesky(CusolverContext* context,
                                         ScratchAllocator* allocator,
                                         HloInstruction* operand,
                                         const CholeskyOptions& options,
                                         const OpMetadata& metadata) {
  HloComputation* computation = operand->parent();

  Shape a_shape = operand->shape();
  int ndim = a_shape.dimensions_size();
  CHECK_GE(ndim, 2);
  int64 n = a_shape.dimensions(ndim - 1);

  int64 batch_size = std::accumulate(a_shape.dimensions().begin(),
                                     a_shape.dimensions().end() - 2, int64{1},
                                     [](int64 a, int64 b) { return a * b; });

  // Find the workspace size.
  se::blas::UpperLower uplo = options.lower() ? se::blas::UpperLower::kLower
                                              : se::blas::UpperLower::kUpper;
  int64 workspace_size;  // Number of elements of size a_shape.element_type()
  switch (a_shape.element_type()) {
    case F32: {
      TF_ASSIGN_OR_RETURN(auto a,
                          allocator->Allocate<float>(context->stream(), n * n));
      TF_ASSIGN_OR_RETURN(workspace_size,
                          context->PotrfBufferSize(uplo, n, a, n));
      break;
    }
    case F64: {
      TF_ASSIGN_OR_RETURN(
          auto a, allocator->Allocate<double>(context->stream(), n * n));
      TF_ASSIGN_OR_RETURN(workspace_size,
                          context->PotrfBufferSize(uplo, n, a, n));
      break;
    }
    case C64: {
      TF_ASSIGN_OR_RETURN(auto a, allocator->Allocate<std::complex<float>>(
                                      context->stream(), n * n));
      TF_ASSIGN_OR_RETURN(workspace_size,
                          context->PotrfBufferSize(uplo, n, a, n));
      break;
    }
    case C128: {
      TF_ASSIGN_OR_RETURN(auto a, allocator->Allocate<std::complex<double>>(
                                      context->stream(), n * n));
      TF_ASSIGN_OR_RETURN(workspace_size,
                          context->PotrfBufferSize(uplo, n, a, n));
      break;
    }
    default:
      return InvalidArgument("Invalid type for cholesky decomposition: %s",
                             a_shape.ToString());
  }

  // TODO(phawkins): Ideally we would relax this constraint. What we actually
  // want is that:
  // a) the batch dimensions are major, in no particular order.
  // b) the two minor dimensions are in fortran (column-major) order.
  SetFortranLayout(&a_shape);

  // This call returns a tuple of (cholesky_result, workspace, info) where:
  // * cholesky_result is the result of the Cholesky decomposition,
  // * workspace is temporary scratch memory used by cuSolver.
  // * info contains the Potrf success/failure status.
  // Currently we have no meaningful way to report an error, so we simply
  // discard the success/failure information. Obviously this is suboptimal.
  Shape call_shape = ShapeUtil::MakeTupleShape(
      {a_shape,
       ShapeUtil::MakeShape(operand->shape().element_type(), {workspace_size}),
       ShapeUtil::MakeShape(S32, {batch_size})});

  HloInstruction* custom_call =
      computation->AddInstruction(HloInstruction::CreateCustomCall(
          call_shape, {operand}, kCusolverCholeskyCallTarget, {a_shape}));
  custom_call->set_metadata(metadata);
  TF_RETURN_IF_ERROR(custom_call->set_backend_config(options));
  return custom_call;
}

}  // namespace

// Tries to rewrite a single kCholesky instruction into a CustomCall to
// cuSolver.
StatusOr<bool> RunOnInstruction(CusolverContext* context,
                                ScratchAllocator* allocator,
                                HloInstruction* instruction) {
  if (instruction->opcode() != HloOpcode::kCholesky) {
    return false;
  }

  TF_ASSIGN_OR_RETURN(
      HloInstruction * custom_call,
      CreateCholesky(context, allocator, instruction->mutable_operand(0),
                     instruction->cholesky_options(), instruction->metadata()));

  VLOG(1) << "Replacing " << instruction->ToString() << " with "
          << custom_call->ToString();

  // The CustomCall returns a tuple (cholesky_result, workspace, info). Extract
  // out the Cholesky result and replace the original instruction with it.
  TF_RETURN_IF_ERROR(instruction->parent()->ReplaceWithNewInstruction(
      instruction, HloInstruction::CreateGetTupleElement(instruction->shape(),
                                                         custom_call, 0)));
  return true;
}

// Rewrites the kCholesky instructions in the given computation into custom
// calls to cuSolver. Returns true if it made any changes.
StatusOr<bool> CusolverRewriter::RunOnComputation(HloComputation* computation) {
  std::vector<HloInstruction*> cusolver_calls;
  for (auto* hlo : computation->instructions()) {
    if (hlo->opcode() == HloOpcode::kCholesky) {
      cusolver_calls.push_back(hlo);
    }
  }

  if (cusolver_calls.empty()) {
    return false;
  }

  // Create a stream for us to do our work on. We don't really need to do any
  // work, just allocate memory, but that's the cuSolver API.
  se::Stream stream{stream_exec_};
  stream.Init();
  const auto device_ordinal = stream_exec_->device_ordinal();

  // allocator either points to this->allocator_ or, if that's null, to a
  // StreamExecutorMemoryAllocator for stream_exec_.
  DeviceMemoryAllocator* allocator;
  absl::optional<StreamExecutorMemoryAllocator> se_allocator;
  if (allocator_ != nullptr) {
    allocator = allocator_;
  } else {
    se_allocator.emplace(stream_exec_->platform(),
                         absl::Span<se::StreamExecutor* const>({stream_exec_}));
    allocator = &*se_allocator;
  }
  ScratchAllocator scratch_allocator(device_ordinal, allocator);

  TF_ASSIGN_OR_RETURN(CusolverContext context,
                      CusolverContext::Create(&stream));

  bool changed = false;
  for (HloInstruction* instruction : cusolver_calls) {
    TF_ASSIGN_OR_RETURN(
        bool result,
        RunOnInstruction(&context, &scratch_allocator, instruction));
    changed |= result;
  }
  return changed;
}

CusolverRewriter::CusolverRewriter(se::StreamExecutor* stream_exec,
                                   DeviceMemoryAllocator* allocator)
    : stream_exec_(stream_exec), allocator_(allocator) {}

StatusOr<bool> CusolverRewriter::Run(HloModule* module) {
  bool changed = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
    changed |= result;
  }
  return changed;
}

}  // namespace gpu
}  // namespace xla
tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h (new file, 48 lines)
@@ -0,0 +1,48 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_

#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/gpu/cusolver_context.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {
namespace gpu {

// Rewrites Cholesky calls into CustomCall HLOs that call into cuSolver.
class CusolverRewriter : public HloModulePass {
 public:
  CusolverRewriter(se::StreamExecutor* stream_exec,
                   DeviceMemoryAllocator* allocator);
  absl::string_view name() const override { return "cusolver-rewriter"; }

  StatusOr<bool> Run(HloModule* module) override;

 private:
  StatusOr<bool> RunOnComputation(HloComputation* computation);

  se::StreamExecutor* stream_exec_;  // never null
  DeviceMemoryAllocator* allocator_;  // may be null
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_
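As a sketch of how this pass is driven (the nvptx_compiler.cc hunk further below adds it to the conv_canonicalization pipeline), it can also be run on a module directly; `module` and `stream_exec` are assumed to exist, and the allocator may be null per the comment above:

  CusolverRewriter rewriter(stream_exec, /*allocator=*/nullptr);
  TF_ASSIGN_OR_RETURN(bool changed, rewriter.Run(module));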
@@ -142,6 +142,16 @@ bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
         target == kCudnnConvBiasActivationForwardCallTarget;
}

const char* const kCusolverCholeskyCallTarget = "__cusolver$cholesky";

bool IsCustomCallToCusolver(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCusolverCholeskyCallTarget;
}

bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
  return ImplementedAsGemm(hlo) || IsCustomCallToDnnBatchNorm(hlo) ||
         IsCustomCallToDnnConvolution(hlo);
@@ -131,6 +131,19 @@ extern const char* const kCudnnConvBiasActivationForwardCallTarget;
// kConvolution opcode.
bool IsCustomCallToDnnConvolution(const HloInstruction& hlo);

// Returns true if `hlo` will be implemented as a call to a cuSolver routine.
//
// This returns true if `hlo` is a CustomCall HLO with a call target equal to
// one of the kCusolver... constants, but returns *false* for HLOs with,
// say, a kCholesky opcode.
bool IsCustomCallToCusolver(const HloInstruction& hlo);

// Cholesky decomposition. Takes a (batched) matrix as input, and returns a
// tuple of (result, workspace, info), where result is the result of the
// Cholesky decomposition, workspace is scratch space for cuSolver, and info
// is a success/failure code per batch element.
extern const char* const kCusolverCholeskyCallTarget;

// Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm
// or cuDNN convolution.
bool ImplementedAsLibraryCall(const HloInstruction& hlo);
@@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
@@ -480,6 +481,51 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
    return Status::OK();
  }

  if (custom_call->custom_call_target() == kCusolverCholeskyCallTarget) {
    TF_ASSIGN_OR_RETURN(CholeskyOptions options,
                        custom_call->backend_config<CholeskyOptions>());

    const Shape& shape = custom_call->operand(0)->shape();
    int ndim = shape.dimensions_size();
    CHECK_GE(ndim, 2);
    int64 n = shape.dimensions(ndim - 1);

    const auto& dims = shape.dimensions();
    int64 batch_size = std::accumulate(dims.begin(), dims.end() - 2, int64{1},
                                       [](int64 a, int64 b) { return a * b; });

    auto operand_buffer = GetAllocationSlice(*custom_call->operand(0));

    const auto& assn = ir_emitter_context_->buffer_assignment();
    auto a_buffer = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
    auto workspace_buffer = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
    auto info_buffer = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie();

    std::vector<std::unique_ptr<Thunk>> thunks;

    if (operand_buffer != a_buffer) {
      thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
          /*source_address=*/operand_buffer,
          /*destination_buffer=*/a_buffer,
          /*mem_size=*/ShapeUtil::ByteSizeOf(shape), custom_call));
    }

    thunks.push_back(absl::make_unique<CholeskyThunk>(
        options, a_buffer, workspace_buffer, info_buffer,
        custom_call->operand(0)->shape().element_type(), batch_size, n,
        custom_call));

    // Elide the sequential thunk if there's no copy.
    if (thunks.size() == 1) {
      AddThunkToThunkSequence(std::move(thunks[0]));
    } else {
      AddThunkToThunkSequence(
          absl::make_unique<SequentialThunk>(std::move(thunks), custom_call));
    }

    return Status::OK();
  }

  return IrEmitter::HandleCustomCall(custom_call);
}
@@ -320,6 +320,9 @@ class IrEmitterUnnested : public IrEmitter {
  // Returns a FftThunk that calls cuFFT to implement `inst`.
  std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst);

  // Returns a CholeskyThunk that calls cuSolver to implement `inst`.
  std::unique_ptr<Thunk> BuildCholeskyThunk(const HloInstruction* inst);

  // Returns a TriangularSolveThunk that calls cuBlas to implement `inst`.
  std::unique_ptr<Thunk> BuildTriangularSolveThunk(const HloInstruction* inst);
@@ -35,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
@@ -47,6 +46,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
@@ -188,8 +188,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
      &pipeline, hlo_module->config().debug_options(),
      ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);

  pipeline.AddPass<CholeskyExpander>();

  // TODO(b/64094172): make Call work on GPU instead of inlining.
  pipeline.AddPass<CallInliner>();
  auto cost_model = [](HloInstruction* conv) {
@@ -267,10 +265,11 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,

  {
    // Convert convolutions into CustomCalls to cudnn, then canonicalize them
    // (CudnnConvPaddingLegalization).
    // (CudnnConvPaddingLegalization). Also expand cuSolver calls.
    HloPassPipeline pipeline("conv_canonicalization");
    pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                              /*allow_mixed_precision=*/false);
    pipeline.AddPass<CusolverRewriter>(stream_exec, device_allocator);
    pipeline.AddPass<CudnnConvRewriter>();
    pipeline.AddPass<CudnnFusedConvRewriter>();
    pipeline.AddPass<CudnnConvPaddingLegalization>();
@@ -343,6 +342,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
    // wouldn't be able to simplify away the new_tuple bits.
    pipeline.AddPass<CudnnConvAlgorithmPicker>(stream_exec, device_allocator,
                                               compiler);

    // Clean up new_tuple described above.
    pipeline.AddPass<TupleSimplifier>();
tensorflow/compiler/xla/service/gpu/scratch_allocator.cc (new file, 43 lines)
@@ -0,0 +1,43 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/scratch_allocator.h"

namespace xla {
namespace gpu {

StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
    se::Stream* stream, int64 byte_size) {
  CHECK_GE(byte_size, 0) << "byte_size must be nonnegative.";
  if (byte_size > GetMemoryLimitInBytes(stream)) {
    return se::port::Status(
        se::port::error::RESOURCE_EXHAUSTED,
        absl::StrFormat(
            "Allocating %d bytes exceeds the memory limit of %d bytes.",
            byte_size, GetMemoryLimitInBytes(stream)));
  }

  TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer,
                      memory_allocator_->Allocate(device_ordinal_, byte_size,
                                                  /*retry_on_failure=*/false));
  total_allocated_bytes_ += byte_size;

  se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase();
  allocated_buffers_.push_back(std::move(allocated_buffer));
  return se::DeviceMemory<uint8>(buffer_addr);
}

}  // namespace gpu
}  // namespace xla
tensorflow/compiler/xla/service/gpu/scratch_allocator.h (new file, 61 lines)
@@ -0,0 +1,61 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_SCRATCH_ALLOCATOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_SCRATCH_ALLOCATOR_H_

#include <vector>

#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/owning_device_memory.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {
namespace gpu {

class ScratchAllocator : public se::ScratchAllocator {
 public:
  ScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator)
      : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}

  int64 GetMemoryLimitInBytes(se::Stream* stream) override {
    return 1LL << 32;  // 4GB. TODO(jlebar): Tune this?
  }
  int64 TotalAllocatedBytes() { return total_allocated_bytes_; }

  StatusOr<se::DeviceMemory<uint8>> AllocateBytes(se::Stream* stream,
                                                  int64 byte_size) override;

  template <typename T>
  StatusOr<se::DeviceMemory<T>> Allocate(se::Stream* stream,
                                         int64 num_elements) {
    TF_ASSIGN_OR_RETURN(se::DeviceMemory<uint8> bytes,
                        AllocateBytes(stream, num_elements * sizeof(T)));
    return se::DeviceMemory<T>(bytes);
  }

 private:
  const int device_ordinal_;
  DeviceMemoryAllocator* memory_allocator_;
  std::vector<OwningDeviceMemory> allocated_buffers_;
  int64 total_allocated_bytes_ = 0;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_SCRATCH_ALLOCATOR_H_
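A small usage sketch of the typed Allocate<T> helper above, mirroring how cusolver_rewriter.cc sizes its cuSolver query buffer (`stream`, `memory_allocator`, and `n` are assumed to exist):

  ScratchAllocator scratch(/*device_ordinal=*/0, memory_allocator);
  TF_ASSIGN_OR_RETURN(se::DeviceMemory<float> a,
                      scratch.Allocate<float>(stream, n * n));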
@@ -20,6 +20,8 @@ namespace gpu {

std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) {
  switch (kind) {
    case Thunk::kCholesky:
      return os << "kCholesky";
    case Thunk::kConditional:
      return os << "kConditional";
    case Thunk::kConvolution:
@@ -42,6 +42,7 @@ class GpuExecutable;

class Thunk {
 public:
  enum Kind {
    kCholesky,
    kConditional,
    kConvolution,
    kCopy,
@@ -260,6 +260,16 @@ Status Unavailable(const absl::FormatSpec<Args...>& format,
  return WithLogBacktrace(
      tensorflow::errors::Unavailable(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status Unknown(const absl::FormatSpec<Args...>& format, const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Unknown(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status Internal(const absl::FormatSpec<Args...>& format, const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Internal(absl::StrFormat(format, args...)));
}

template <typename... Args>
Status InvalidArgumentStrCat(Args&&... concat) {