[SE] Remove Stream* argument from ScratchAllocator methods

The Stream* arguments are unused (apart from RedzoneAllocator, which only
ever operates on a single stream) and were needlessly propagated through
the wrapper classes. RedzoneAllocator now takes the stream in its
constructor and stores it as a member instead.

PiperOrigin-RevId: 261726698
Author: George Karpenkov (committed by TensorFlower Gardener)
Date: 2019-08-05 11:17:46 -07:00
Commit: 386a8d7702 (parent 49d23307d5)
19 changed files with 102 additions and 129 deletions
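
At a glance, the interface change (condensed from the stream_executor/scratch_allocator.h diff below; surrounding declarations elided):

  // Before: every caller had to thread a Stream* through each call.
  virtual int64 GetMemoryLimitInBytes(Stream* stream) = 0;
  virtual port::StatusOr<DeviceMemory<uint8>> AllocateBytes(
      Stream* stream, int64 byte_size) = 0;

  // After: no Stream* parameter. Allocators that actually need a stream
  // (RedzoneAllocator, OneTimeScratchAllocator) now take it in their
  // constructor and keep it as a member.
  virtual int64 GetMemoryLimitInBytes() = 0;
  virtual port::StatusOr<DeviceMemory<uint8>> AllocateBytes(
      int64 byte_size) = 0;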


@@ -143,10 +143,8 @@ StatusOr<bool> CheckRedzones(const se::cuda::RedzoneAllocator& allocator,
   XLA_SCOPED_LOGGING_TIMER_LEVEL("CudnnConvAlgorithmPicker checking redzones",
                                  2);
   using RedzoneCheckStatus = se::cuda::RedzoneAllocator::RedzoneCheckStatus;
   TF_ASSIGN_OR_RETURN(RedzoneCheckStatus redzone_check,
-                      allocator.CheckRedzones(stream));
+                      allocator.CheckRedzones());
   if (redzone_check.ok()) {
     return true;
   }
@@ -253,8 +251,6 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
   // Create a stream for us to do our work on.
   se::Stream stream{stream_exec_};
   stream.Init();
-  const auto device_ordinal = stream_exec_->device_ordinal();
   // allocator either points to this->allocator_ or, if that's null, to a
   // se::StreamExecutorMemoryAllocator for stream_exec_.
   se::DeviceMemoryAllocator* allocator;
@@ -278,18 +274,18 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
   // Allocate space for the input, filter, and output of the convolution.
   se::cuda::RedzoneAllocator input_output_allocator(
-      device_ordinal, allocator, PtxOptsFromConfig(hlo_module_config));
+      &stream, allocator, PtxOptsFromConfig(hlo_module_config));
   std::vector<se::DeviceMemoryBase> operand_buffers;
   for (const auto* operand : instr->operands()) {
     TF_ASSIGN_OR_RETURN(auto buffer,
                         input_output_allocator.AllocateBytes(
-                            &stream, ShapeUtil::ByteSizeOf(operand->shape())));
+                            ShapeUtil::ByteSizeOf(operand->shape())));
     initialize_buffer(buffer);
     operand_buffers.push_back(buffer);
   }
   TF_ASSIGN_OR_RETURN(auto result_buffer,
                       input_output_allocator.AllocateBytes(
-                          &stream, ShapeUtil::ByteSizeOf(result_shape)));
+                          ShapeUtil::ByteSizeOf(result_shape)));
   initialize_buffer(result_buffer);
   TF_ASSIGN_OR_RETURN(auto backend_config,
@@ -331,7 +327,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
   }
   se::cuda::RedzoneAllocator scratch_allocator(
-      device_ordinal, allocator, PtxOptsFromConfig(hlo_module_config));
+      &stream, allocator, PtxOptsFromConfig(hlo_module_config));
   se::dnn::ProfileResult profile_result;
   VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
           << instr->ToString();
@@ -431,7 +427,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
   comparator.emplace(result_shape, hlo_module_config);
   TF_ASSIGN_OR_RETURN(
       reference_result_buffer,
-      input_output_allocator.AllocateBytes(&stream, result_buffer.size()));
+      input_output_allocator.AllocateBytes(result_buffer.size()));
   stream.ThenMemcpy(&reference_result_buffer, result_buffer,
                     result_buffer.size());
   first_algorithm = alg;


@@ -48,12 +48,10 @@ class ScratchBufAllocator : public se::ScratchAllocator {
   ~ScratchBufAllocator() override = default;
-  int64 GetMemoryLimitInBytes(se::Stream* /*stream*/) override {
-    return scratch_.size();
-  }
+  int64 GetMemoryLimitInBytes() override { return scratch_.size(); }
   se::port::StatusOr<DeviceMemory<uint8>> AllocateBytes(
-      se::Stream* stream, int64 byte_size) override {
+      int64 byte_size) override {
     if (allocated_) {
       return se::port::InternalError(
           "Can't allocate twice from a ScratchBufAllocator.");


@@ -32,20 +32,20 @@ FftScratchAllocator::FftScratchAllocator(
     int device_ordinal, se::DeviceMemoryAllocator* memory_allocator)
     : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
-int64 FftScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) {
+int64 FftScratchAllocator::GetMemoryLimitInBytes() {
   constexpr int64 kFftScratchSize = 1LL << 32;  // 4GB by default.
   return kFftScratchSize;
 }
 StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes(
-    se::Stream* stream, int64 byte_size) {
+    int64 byte_size) {
   CHECK_GE(byte_size, 0) << "byte_size must be positive.";
-  if (byte_size > GetMemoryLimitInBytes(stream)) {
+  if (byte_size > GetMemoryLimitInBytes()) {
     return se::port::Status(
         se::port::error::RESOURCE_EXHAUSTED,
         absl::StrFormat(
             "Allocating %d bytes exceeds the memory limit of %d bytes.",
-            byte_size, GetMemoryLimitInBytes(stream)));
+            byte_size, GetMemoryLimitInBytes()));
   }
   TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer,


@@ -40,12 +40,12 @@ class FftScratchAllocator : public se::ScratchAllocator {
   FftScratchAllocator(int device_ordinal,
                       se::DeviceMemoryAllocator* memory_allocator);
-  int64 GetMemoryLimitInBytes(se::Stream* stream) override;
+  int64 GetMemoryLimitInBytes() override;
   int64 TotalAllocatedBytes() { return total_allocated_bytes_; }
   se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
-      se::Stream* stream, int64 byte_size) override;
+      int64 byte_size) override;
  private:
   const int device_ordinal_;


@@ -110,7 +110,7 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
     TF_ASSIGN_OR_RETURN(
         se::cuda::RedzoneAllocator::RedzoneCheckStatus rz_check_status,
-        allocator.CheckRedzones(stream));
+        allocator.CheckRedzones());
     if (!rz_check_status.ok()) {
       result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED);
       *result.mutable_failure()->mutable_msg() =
@@ -244,8 +244,7 @@ static StatusOr<bool> RunOnInstruction(HloInstruction* instr,
   const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
   se::cuda::RedzoneAllocator input_output_allocator(
-      executor->device_ordinal(), allocator,
-      PtxOptsFromConfig(hlo_module_config));
+      &stream, allocator, PtxOptsFromConfig(hlo_module_config));
   BufferComparator comparator(instr->shape(), hlo_module_config);
@@ -254,7 +253,7 @@ static StatusOr<bool> RunOnInstruction(HloInstruction* instr,
       [&](const HloInstruction* op) -> StatusOr<se::DeviceMemoryBase> {
         TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer,
                             input_output_allocator.AllocateBytes(
-                                &stream, ShapeUtil::ByteSizeOf(op->shape())));
+                                ShapeUtil::ByteSizeOf(op->shape())));
         InitializeFloatBuffer(&stream, op->shape().element_type(), &rng_state,
                               buffer);
         return buffer;


@@ -265,10 +265,10 @@ class BlasScratchAllocator : public se::ScratchAllocator {
   BlasScratchAllocator(OpKernelContext* context) : context_(context) {}
-  int64 GetMemoryLimitInBytes(Stream* stream) override { return -1; }
+  int64 GetMemoryLimitInBytes() override { return -1; }
   se::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
-      Stream* stream, int64 byte_size) override {
+      int64 byte_size) override {
     Tensor temporary_memory;
     Status allocation_status(context_->allocate_temp(


@@ -613,10 +613,9 @@ typedef AutoTuneSingleton<ConvAutoTuneGroup, ConvParameters,
 // If violations have occurred, mark the corresponding autotune result
 // as a failure.
 static void CheckRedzones(const se::cuda::RedzoneAllocator& rz_allocator,
-                          se::Stream* stream,
                           tensorflow::AutotuneResult* autotune_result) {
   se::port::StatusOr<se::cuda::RedzoneAllocator::RedzoneCheckStatus> rz_status =
-      rz_allocator.CheckRedzones(stream);
+      rz_allocator.CheckRedzones();
   if (!rz_status.ok()) {
     static std::once_flag failure_logged;
     std::call_once(failure_logged, [&]() {
@@ -1003,14 +1002,12 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
   se::TfAllocatorAdapter tf_allocator_adapter(
       stream->parent()->platform(), ctx->device()->GetAllocator({}));
-  se::cuda::RedzoneAllocator rz_allocator(stream->parent()->device_ordinal(),
-                                          &tf_allocator_adapter,
+  se::cuda::RedzoneAllocator rz_allocator(stream, &tf_allocator_adapter,
                                           se::cuda::PtxCompilationOptions());
   se::DeviceMemory<T> output_tensor;
   if (!RedzoneCheckDisabled()) {
-    auto output_rz_or = rz_allocator.AllocateBytes(stream, output_ptr.size());
+    auto output_rz_or = rz_allocator.AllocateBytes(output_ptr.size());
     if (!output_rz_or.ok()) {
       static std::once_flag rz_allocation_failure_logged;
       std::call_once(rz_allocation_failure_logged, []() {
@@ -1033,8 +1030,7 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
       // TODO(zhengxq): profile each algorithm multiple times to better
       // accuracy.
       se::cuda::RedzoneAllocator rz_scratch_allocator(
-          stream->parent()->device_ordinal(), &tf_allocator_adapter,
-          se::cuda::PtxCompilationOptions());
+          stream, &tf_allocator_adapter, se::cuda::PtxCompilationOptions());
       DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
       se::ScratchAllocator* allocator_used =
           !RedzoneCheckDisabled()
@@ -1061,8 +1057,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
         *result.mutable_run_time() = proto_utils::ToDurationProto(
             absl::Milliseconds(profile_result.elapsed_time_in_ms()));
-        CheckRedzones(rz_scratch_allocator, stream, &result);
-        CheckRedzones(rz_allocator, stream, &result);
+        CheckRedzones(rz_scratch_allocator, &result);
+        CheckRedzones(rz_allocator, &result);
       }
     }
   LogConvAutotuneResults(se::dnn::ConvolutionKind::FORWARD,


@@ -50,11 +50,9 @@ class DnnScratchAllocator : public se::ScratchAllocator {
   virtual ~DnnScratchAllocator() {}
   DnnScratchAllocator(int64 memory_limit, OpKernelContext* context)
       : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
-  int64 GetMemoryLimitInBytes(se::Stream* stream) override {
-    return memory_limit_;
-  }
+  int64 GetMemoryLimitInBytes() override { return memory_limit_; }
   se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
-      se::Stream* stream, int64 byte_size) override {
+      int64 byte_size) override {
     Tensor temporary_memory;
     if (byte_size < 0) {
       return se::port::Status{se::port::error::INVALID_ARGUMENT,


@@ -363,12 +363,11 @@ class CudnnRnnAllocatorInTemp : public ScratchAllocator {
   explicit CudnnRnnAllocatorInTemp(OpKernelContext* context)
       : context_(context) {}
-  int64 GetMemoryLimitInBytes(Stream* stream) override {
+  int64 GetMemoryLimitInBytes() override {
     return std::numeric_limits<int64>::max();
   }
-  StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
-                                              int64 byte_size) override {
+  StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
     Tensor temporary_memory;
     const DataType tf_data_type = ToTFDataType<T>::value;
     int64 allocate_count =
@@ -409,11 +408,10 @@ class CudnnRnnAllocatorInOutput : public ScratchAllocator {
   ~CudnnRnnAllocatorInOutput() override {}
   CudnnRnnAllocatorInOutput(OpKernelContext* context, int output_index)
       : context_(context), output_index_(output_index) {}
-  int64 GetMemoryLimitInBytes(Stream* stream) override {
+  int64 GetMemoryLimitInBytes() override {
     return std::numeric_limits<int64>::max();
   }
-  StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
-                                              int64 byte_size) override {
+  StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
     CHECK(total_byte_size_ == 0)
         << "Reserve space allocator can only be called once";
     int64 allocate_count =
@@ -449,12 +447,11 @@ class CudnnRNNPersistentSpaceAllocator : public ScratchAllocator {
   ~CudnnRNNPersistentSpaceAllocator() override {}
-  int64 GetMemoryLimitInBytes(Stream* stream) override {
+  int64 GetMemoryLimitInBytes() override {
     return std::numeric_limits<int64>::max();
   }
-  StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
-                                              int64 byte_size) override {
+  StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
     if (total_byte_size_ != 0) {
       return Status(error::FAILED_PRECONDITION,
                     "Persistent space allocator can only be called once");


@@ -315,11 +315,9 @@ class CufftScratchAllocator : public se::ScratchAllocator {
   ~CufftScratchAllocator() override {}
   CufftScratchAllocator(int64 memory_limit, OpKernelContext* context)
       : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
-  int64 GetMemoryLimitInBytes(se::Stream* stream) override {
-    return memory_limit_;
-  }
+  int64 GetMemoryLimitInBytes() override { return memory_limit_; }
   se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
-      se::Stream* stream, int64 byte_size) override {
+      int64 byte_size) override {
     Tensor temporary_memory;
     if (byte_size > memory_limit_) {
       return se::port::StatusOr<se::DeviceMemory<uint8>>();


@@ -101,12 +101,11 @@ class CudnnBatchNormAllocatorInTemp : public ScratchAllocator {
   explicit CudnnBatchNormAllocatorInTemp(OpKernelContext* context)
       : context_(context) {}
-  int64 GetMemoryLimitInBytes(Stream* stream) override {
+  int64 GetMemoryLimitInBytes() override {
     return std::numeric_limits<int64>::max();
   }
-  StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
-                                              int64 byte_size) override {
+  StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
     Tensor temporary_memory;
     const DataType tf_data_type = DataTypeToEnum<T>::v();
     int64 allocate_count =
@@ -155,12 +154,11 @@ class CudnnBatchNormAllocatorInOutput : public ScratchAllocator {
   CudnnBatchNormAllocatorInOutput(OpKernelContext* context, int output_index)
       : context_(context), output_index_(output_index) {}
-  int64 GetMemoryLimitInBytes(Stream* stream) override {
+  int64 GetMemoryLimitInBytes() override {
     return std::numeric_limits<int64>::max();
   }
-  StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
-                                              int64 byte_size) override {
+  StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
     output_allocated = true;
     DCHECK(total_byte_size_ == 0)
         << "Reserve space allocator can only be called once";


@@ -2179,11 +2179,11 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal(
   // whether a scratch allocator was passed.
   if (scratch_allocator != nullptr) {
     SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> a_bytes,
-                        scratch_allocator->AllocateBytes(stream, size));
+                        scratch_allocator->AllocateBytes(size));
     SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> b_bytes,
-                        scratch_allocator->AllocateBytes(stream, size));
+                        scratch_allocator->AllocateBytes(size));
     SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> c_bytes,
-                        scratch_allocator->AllocateBytes(stream, size));
+                        scratch_allocator->AllocateBytes(size));
     a = DeviceMemory<CUDA_T *>(a_bytes);
     b = DeviceMemory<CUDA_T *>(b_bytes);
     c = DeviceMemory<CUDA_T *>(c_bytes);


@@ -952,8 +952,8 @@ class CudnnDropoutDescriptor {
     size_t state_sizes_in_bytes = 0;
     RETURN_IF_CUDNN_ERROR(
         cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes));
-    SE_ASSIGN_OR_RETURN(state_memory, state_allocator->AllocateBytes(
-                                          nullptr, state_sizes_in_bytes));
+    SE_ASSIGN_OR_RETURN(state_memory,
+                        state_allocator->AllocateBytes(state_sizes_in_bytes));
   }
   RETURN_IF_CUDNN_ERROR(cudnnSetDropoutDescriptor(
       handle.get(), cudnn.handle(), dropout, state_memory.opaque(),
@@ -1603,7 +1603,7 @@ port::StatusOr<DeviceMemory<uint8>> CreateRnnWorkspace(
   if (workspace_size_in_bytes == 0) {
     return DeviceMemory<uint8>();
   }
-  return workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
+  return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
 }
 #if CUDNN_VERSION >= 7402
@@ -1628,7 +1628,7 @@ port::StatusOr<DeviceMemory<uint8>> CreateBatchNormForwardWorkspace(
   if (workspace_size_in_bytes == 0) {
     return DeviceMemory<uint8>();
   }
-  return workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
+  return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
 }
 port::StatusOr<DeviceMemory<uint8>> CreateBatchNormBackwardWorkspace(
@@ -1652,7 +1652,7 @@ port::StatusOr<DeviceMemory<uint8>> CreateBatchNormBackwardWorkspace(
   if (workspace_size_in_bytes == 0) {
     return DeviceMemory<uint8>();
   }
-  return workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
+  return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
 }
 #endif
@@ -1701,9 +1701,8 @@ port::Status CudnnSupport::DoRnnForwardImpl(
         /*sizeInBytes=*/&reserve_space_size_in_bytes));
     if (reserve_space_size_in_bytes > 0) {
-      SE_ASSIGN_OR_RETURN(reserve_space,
-                          reserve_space_allocator->AllocateBytes(
-                              stream, reserve_space_size_in_bytes));
+      SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
+                                             reserve_space_size_in_bytes));
     }
   }
@@ -2401,7 +2400,7 @@ port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
         "No scratch allocator provided");
   }
-  return scratch_allocator->AllocateBytes(stream, size_in_bytes);
+  return scratch_allocator->AllocateBytes(size_in_bytes);
 }
 port::StatusOr<DeviceMemory<uint8>>
@@ -2446,7 +2445,7 @@ AllocateCudnnConvolutionBackwardDataWorkspace(
         "No scratch allocator provided");
   }
-  return scratch_allocator->AllocateBytes(stream, size_in_bytes);
+  return scratch_allocator->AllocateBytes(size_in_bytes);
 }
 port::StatusOr<DeviceMemory<uint8>>
@@ -2491,7 +2490,7 @@ AllocateCudnnConvolutionBackwardFilterWorkspace(
         "No scratch allocator provided");
   }
-  return scratch_allocator->AllocateBytes(stream, size_in_bytes);
+  return scratch_allocator->AllocateBytes(size_in_bytes);
 }
 static bool TensorOpMathAvailable(int cc_major) {
@@ -2512,7 +2511,7 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
   bool specify_workspace_limit = scratch_allocator != nullptr;
   auto memory_limit_bytes =
       specify_workspace_limit
-          ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll)
+          ? std::max(scratch_allocator->GetMemoryLimitInBytes(), 0ll)
           : 0ll;
   SE_ASSIGN_OR_RETURN(cudnnConvolutionFwdAlgo_t algo,
                       GetCudnnConvolutionForwardAlgo(
@@ -2565,7 +2564,7 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
   bool specify_workspace_limit = scratch_allocator != nullptr;
   auto memory_limit_bytes =
       specify_workspace_limit
-          ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll)
+          ? std::max(scratch_allocator->GetMemoryLimitInBytes(), 0ll)
          : 0ll;
   SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdDataAlgo_t algo,
                       GetCudnnConvolutionBackwardDataAlgo(
@@ -2617,7 +2616,7 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
   bool specify_workspace_limit = scratch_allocator != nullptr;
   auto memory_limit_bytes =
       specify_workspace_limit
-          ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll)
+          ? std::max(scratch_allocator->GetMemoryLimitInBytes(), 0ll)
          : 0ll;
   SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdFilterAlgo_t algo,
                       GetCudnnConvolutionBackwardFilterAlgo(
@@ -3470,9 +3469,8 @@ port::Status CudnnSupport::DoBatchNormalizationForwardImpl(
         /*activationDesc=*/activation_desc.handle(),
         /*xDesc=*/x_descriptor.handle(),
         /*sizeInBytes=*/&reserve_space_size_in_bytes));
-    SE_ASSIGN_OR_RETURN(reserve_space,
-                        reserve_space_allocator->AllocateBytes(
-                            stream, reserve_space_size_in_bytes));
+    SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
+                                           reserve_space_size_in_bytes));
   }
 }
 #endif


@@ -244,8 +244,7 @@ port::Status CUDAFftPlan::Initialize(GpuExecutor *parent, Stream *stream,
 port::Status CUDAFftPlan::UpdateScratchAllocator(
     Stream *stream, ScratchAllocator *scratch_allocator) {
   if (scratch_size_bytes_ != 0) {
-    auto allocated =
-        scratch_allocator->AllocateBytes(stream, scratch_size_bytes_);
+    auto allocated = scratch_allocator->AllocateBytes(scratch_size_bytes_);
     if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
       LOG(ERROR) << "failed to allocate work area.";
       return allocated.status();


@@ -45,10 +45,11 @@ constexpr int64 kRhsRedzoneAlign = 4;
 using RedzoneCheckStatus = RedzoneAllocator::RedzoneCheckStatus;
 RedzoneAllocator::RedzoneAllocator(
-    int device_ordinal, DeviceMemoryAllocator* memory_allocator,
+    Stream* stream, DeviceMemoryAllocator* memory_allocator,
     cuda::PtxCompilationOptions ptx_compilation_opts, uint64 redzone_size,
     uint8 redzone_pattern)
-    : device_ordinal_(device_ordinal),
+    : device_ordinal_(stream->parent()->device_ordinal()),
+      stream_(stream),
       redzone_size_(RoundUpToNearest(
           redzone_size,
           static_cast<uint64>(tensorflow::Allocator::kAllocatorAlignment))),
@@ -57,14 +58,14 @@ RedzoneAllocator::RedzoneAllocator(
       ptx_compilation_opts_(ptx_compilation_opts) {}
 port::StatusOr<DeviceMemory<uint8>> RedzoneAllocator::AllocateBytes(
-    Stream* stream, int64 byte_size) {
+    int64 byte_size) {
   CHECK_GE(byte_size, 0) << "byte_size must be positive.";
-  if (byte_size > GetMemoryLimitInBytes(stream)) {
+  if (byte_size > GetMemoryLimitInBytes()) {
     return port::Status(
         port::error::RESOURCE_EXHAUSTED,
         absl::StrFormat(
             "Allocating %d bytes exceeds the memory limit of %d bytes.",
-            byte_size, GetMemoryLimitInBytes(stream)));
+            byte_size, GetMemoryLimitInBytes()));
   }
   int64 rhs_slop = RoundUpToNearest(byte_size, kRhsRedzoneAlign) - byte_size;
@@ -78,10 +79,10 @@ port::StatusOr<DeviceMemory<uint8>> RedzoneAllocator::AllocateBytes(
   static_assert(sizeof(uint8) == 1, "Unexpected size");
   DeviceMemory<uint8> allocated_buffer_memory(*allocated_buffer);
-  DeviceMemory<uint8> lhs_redzone = stream->parent()->GetSubBuffer(
+  DeviceMemory<uint8> lhs_redzone = stream_->parent()->GetSubBuffer(
       &allocated_buffer_memory, 0, redzone_size_);
-  DeviceMemory<uint8> data_chunk = stream->parent()->GetSubBuffer(
+  DeviceMemory<uint8> data_chunk = stream_->parent()->GetSubBuffer(
      &allocated_buffer_memory, redzone_size_, byte_size);
   // Split up the RHS redzone into two pieces:
@@ -89,10 +90,10 @@ port::StatusOr<DeviceMemory<uint8>> RedzoneAllocator::AllocateBytes(
   //  - redzone_size_ bytes.
   // We do this because Stream::ThenMemset32 requires the buffer address and
   // size to be aligned to 4 bytes.
-  DeviceMemory<uint8> rhs_redzone_slop = stream->parent()->GetSubBuffer(
+  DeviceMemory<uint8> rhs_redzone_slop = stream_->parent()->GetSubBuffer(
      &allocated_buffer_memory, redzone_size_ + byte_size, rhs_slop);
-  DeviceMemory<uint8> rhs_redzone_nonslop = stream->parent()->GetSubBuffer(
+  DeviceMemory<uint8> rhs_redzone_nonslop = stream_->parent()->GetSubBuffer(
      &allocated_buffer_memory, redzone_size_ + byte_size + rhs_slop,
      redzone_size_);
@@ -100,11 +101,11 @@ port::StatusOr<DeviceMemory<uint8>> RedzoneAllocator::AllocateBytes(
                          redzone_pattern_};
   uint32 pattern32;
   std::memcpy(&pattern32, pattern_arr, sizeof(pattern32));
-  stream->ThenMemset32(&lhs_redzone, pattern32, redzone_size_);
+  stream_->ThenMemset32(&lhs_redzone, pattern32, redzone_size_);
   if (rhs_slop != 0) {
-    stream->ThenMemcpy(&rhs_redzone_slop, &pattern32, rhs_slop);
+    stream_->ThenMemcpy(&rhs_redzone_slop, &pattern32, rhs_slop);
   }
-  stream->ThenMemset32(&rhs_redzone_nonslop, pattern32, redzone_size_);
+  stream_->ThenMemset32(&rhs_redzone_nonslop, pattern32, redzone_size_);
   allocated_buffers_.emplace_back(std::move(allocated_buffer), byte_size);
   return data_chunk;
@@ -295,9 +296,8 @@ static port::StatusOr<RedzoneCheckStatus> CheckRedzonesForBuffer(
   return RedzoneCheckStatus::OK();
 }
-port::StatusOr<RedzoneCheckStatus> RedzoneAllocator::CheckRedzones(
-    Stream* stream) const {
-  StreamExecutor* executor = stream->parent();
+port::StatusOr<RedzoneCheckStatus> RedzoneAllocator::CheckRedzones() const {
+  StreamExecutor* executor = stream_->parent();
   absl::Span<const uint8> compiled_ptx = {};
   port::StatusOr<absl::Span<const uint8>> compiled_ptx_or =
@@ -316,7 +316,7 @@ port::StatusOr<RedzoneCheckStatus> RedzoneAllocator::CheckRedzones(
   ScopedDeviceMemory<uint64> out_param =
       executor->AllocateOwnedScalar<uint64>();
-  stream->ThenMemZero(out_param.ptr(), sizeof(uint64));
+  stream_->ThenMemZero(out_param.ptr(), sizeof(uint64));
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<ComparisonKernelT> comparison_kernel,
@@ -327,7 +327,7 @@ port::StatusOr<RedzoneCheckStatus> RedzoneAllocator::CheckRedzones(
   for (const auto& buf_and_size : allocated_buffers_) {
     TF_ASSIGN_OR_RETURN(
         RedzoneCheckStatus redzone_status,
-        CheckRedzonesForBuffer(stream, *buf_and_size.first, out_param.cref(),
+        CheckRedzonesForBuffer(stream_, *buf_and_size.first, out_param.cref(),
                                *comparison_kernel, buf_and_size.second,
                                redzone_size_, redzone_pattern_));
     if (!redzone_status.ok()) {


@@ -39,21 +39,20 @@ namespace cuda {
 // memory for cudnn convolutions.
 class RedzoneAllocator : public ScratchAllocator {
  public:
-  RedzoneAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator,
+  RedzoneAllocator(Stream* stream, DeviceMemoryAllocator* memory_allocator,
                    cuda::PtxCompilationOptions ptx_compilation_opts,
                    uint64 redzone_size = 1 << 23,  // 8MiB per side, 16MiB total
                    uint8 redzone_pattern = -1);
   // Redzones don't count towards the memory limit.
-  int64 GetMemoryLimitInBytes(Stream* stream) override {
+  int64 GetMemoryLimitInBytes() override {
     return 1LL << 32;  // 4GB. TODO(jlebar): Tune this?
   }
   int64 TotalAllocatedBytesExcludingRedzones() const {
     return allocated_bytes_excluding_redzones_;
   }
-  port::StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
-                                                    int64 byte_size) override;
+  port::StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override;
   // Non-empty redzone check status implies that there was a write into a
   // redzone, with a string communicating the location of the write.
@@ -92,10 +91,11 @@ class RedzoneAllocator : public ScratchAllocator {
   //  - RedzoneCheckStatus with a non-empty error message iff a write into a
   //    redzone has been detected.
   //  - A stream error, if loading or launching the kernel has failed.
-  port::StatusOr<RedzoneCheckStatus> CheckRedzones(Stream* stream) const;
+  port::StatusOr<RedzoneCheckStatus> CheckRedzones() const;
  private:
   const int device_ordinal_;
+  Stream* stream_;
   // Redzone size on *one side* of allocation.
   //
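
Callers now construct the allocator against a live stream and drop the per-call stream arguments; a minimal usage sketch, matching the updated test below (stream_exec, se_allocator, opts, and the k* constants are caller-provided as in that test):

  Stream stream(stream_exec);
  stream.Init();
  RedzoneAllocator allocator(&stream, &se_allocator, opts, kRedzoneSize);
  TF_ASSERT_OK_AND_ASSIGN(DeviceMemory<uint8> buf,
                          allocator.AllocateBytes(/*byte_size=*/kAllocSize));
  EXPECT_REDZONE_OK(allocator.CheckRedzones());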


@@ -55,15 +55,14 @@ TEST(RedzoneAllocatorTest, WriteToRedzone) {
   StreamExecutor* stream_exec = platform->ExecutorForDevice(0).ValueOrDie();
   cuda::PtxCompilationOptions opts;
   StreamExecutorMemoryAllocator se_allocator(platform, {stream_exec});
-  RedzoneAllocator allocator(/*device_ordinal=*/0, &se_allocator, opts,
-                             kRedzoneSize, kRedzonePattern);
   Stream stream(stream_exec);
   stream.Init();
+  RedzoneAllocator allocator(&stream, &se_allocator, opts, kRedzoneSize,
+                             kRedzonePattern);
   TF_ASSERT_OK_AND_ASSIGN(DeviceMemory<uint8> buf,
-                          allocator.AllocateBytes(&stream,
-                                                  /*byte_size=*/kAllocSize));
-  EXPECT_REDZONE_OK(allocator.CheckRedzones(&stream));
+                          allocator.AllocateBytes(/*byte_size=*/kAllocSize));
+  EXPECT_REDZONE_OK(allocator.CheckRedzones());
   char* buf_addr = reinterpret_cast<char*>(buf.opaque());
   DeviceMemoryBase lhs_redzone(buf_addr - kRedzoneSize, kRedzoneSize);
@@ -100,15 +99,13 @@ TEST(RedzoneAllocatorTest, WriteToRedzone) {
     DeviceMemoryBase redzone_at_offset(
         reinterpret_cast<char*>(redzone.opaque()) + offset, 1);
     char old_redzone_value = 0;
-    {
-      EXPECT_REDZONE_OK(allocator.CheckRedzones(&stream));
-    }
+    { EXPECT_REDZONE_OK(allocator.CheckRedzones()); }
     stream.ThenMemcpy(&old_redzone_value, redzone_at_offset, 1)
         .ThenMemZero(&redzone_at_offset, 1);
-    EXPECT_REDZONE_VIOLATION(allocator.CheckRedzones(&stream));
+    EXPECT_REDZONE_VIOLATION(allocator.CheckRedzones());
     // Checking reinitializes the redzone.
-    EXPECT_REDZONE_OK(allocator.CheckRedzones(&stream));
+    EXPECT_REDZONE_OK(allocator.CheckRedzones());
   };
   modify_redzone(lhs_redzone, /*offset=*/0, "lhs");
@@ -130,12 +127,12 @@ TEST(RedzoneAllocatorTest, VeryLargeRedzone) {
   StreamExecutor* stream_exec = platform->ExecutorForDevice(0).ValueOrDie();
   cuda::PtxCompilationOptions opts;
   StreamExecutorMemoryAllocator se_allocator(platform, {stream_exec});
-  RedzoneAllocator allocator(/*device_ordinal=*/0, &se_allocator, opts,
-                             kRedzoneSize, /*redzone_pattern=*/-1);
   Stream stream(stream_exec);
   stream.Init();
-  (void)allocator.AllocateBytes(&stream, /*byte_size=*/1);
-  EXPECT_REDZONE_OK(allocator.CheckRedzones(&stream));
+  RedzoneAllocator allocator(&stream, &se_allocator, opts, kRedzoneSize,
+                             /*redzone_pattern=*/-1);
+  (void)allocator.AllocateBytes(/*byte_size=*/1);
+  EXPECT_REDZONE_OK(allocator.CheckRedzones());
 }
 }  // namespace


@@ -22,18 +22,17 @@ namespace stream_executor {
 ScratchAllocator::~ScratchAllocator() {}
-OneTimeScratchAllocator::OneTimeScratchAllocator() {}
+OneTimeScratchAllocator::OneTimeScratchAllocator(Stream* stream)
+    : stream_(stream) {}
 OneTimeScratchAllocator::~OneTimeScratchAllocator() {}
-int64 OneTimeScratchAllocator::GetMemoryLimitInBytes(Stream* stream) {
-  return -1;
-}
+int64 OneTimeScratchAllocator::GetMemoryLimitInBytes() { return -1; }
 port::StatusOr<DeviceMemory<uint8>> OneTimeScratchAllocator::AllocateBytes(
-    Stream* stream, int64 byte_size) {
+    int64 byte_size) {
   CHECK(temporary_ == nullptr);
   SE_ASSIGN_OR_RETURN(temporary_,
-                      stream->AllocateTemporaryArray<uint8>(byte_size));
+                      stream_->AllocateTemporaryArray<uint8>(byte_size));
   return temporary_->device_memory();
 }


@@ -45,14 +45,14 @@ class ScratchAllocator {
   // bytes. This information may be used to help select an algorithm.
   //
   // Returns values < 0 to indicate that there is no recommended limit.
-  virtual int64 GetMemoryLimitInBytes(Stream* stream) = 0;
+  virtual int64 GetMemoryLimitInBytes() = 0;
   // Returns an allocation on byte_size bytes for use in an operation on stream.
   //
   // This is a temporary allocation, and the caller is responsible for
   // deallocating at some known-safe point. See the class comment above.
   virtual port::StatusOr<DeviceMemory<uint8>> AllocateBytes(
-      Stream* stream, int64 byte_size) = 0;
+      int64 byte_size) = 0;
 };
 // Allocates a single temporary memory allocation -- this memory is deallocated
@@ -64,14 +64,14 @@ class ScratchAllocator {
 // thread will request the scratch allocation).
 class OneTimeScratchAllocator : public ScratchAllocator {
  public:
-  OneTimeScratchAllocator();
+  explicit OneTimeScratchAllocator(Stream* stream);
   ~OneTimeScratchAllocator() override;
-  int64 GetMemoryLimitInBytes(Stream* stream) override;
-  port::StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
-                                                    int64 byte_size) override;
+  int64 GetMemoryLimitInBytes() override;
+  port::StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override;
  private:
   std::unique_ptr<TemporaryDeviceMemory<uint8>> temporary_;
+  Stream* stream_;
   SE_DISALLOW_COPY_AND_ASSIGN(OneTimeScratchAllocator);
 };
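
For implementers, a minimal sketch of a subclass under the new interface (a hypothetical FixedBufferScratchAllocator, not part of this commit; modeled on the ScratchBufAllocator diff above):

  // Hands out a single caller-provided device buffer as scratch space.
  class FixedBufferScratchAllocator : public ScratchAllocator {
   public:
    explicit FixedBufferScratchAllocator(DeviceMemory<uint8> buffer)
        : buffer_(buffer) {}

    // The limit is now a property of the allocator itself; no Stream* needed.
    int64 GetMemoryLimitInBytes() override { return buffer_.size(); }

    port::StatusOr<DeviceMemory<uint8>> AllocateBytes(
        int64 byte_size) override {
      if (byte_size > static_cast<int64>(buffer_.size())) {
        return port::InternalError("Request exceeds the fixed scratch buffer.");
      }
      return buffer_;
    }

   private:
    DeviceMemory<uint8> buffer_;
  };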