From 386a8d770269c7814b73af13521b8547b3ca481d Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Mon, 5 Aug 2019 11:17:46 -0700 Subject: [PATCH] [SE] Remove Stream* argument from ScratchAllocator methods The arguments are unused (apart from the RedzoneAllocator, which is only used on a single stream), and are unnecessarily propagated through the wrappings. PiperOrigin-RevId: 261726698 --- .../gpu/cudnn_conv_algorithm_picker.cc | 16 ++++----- .../xla/service/gpu/cudnn_conv_runner.cc | 6 ++-- .../compiler/xla/service/gpu/fft_thunk.cc | 8 ++--- .../compiler/xla/service/gpu/fft_thunk.h | 4 +-- .../xla/service/gpu/gemm_algorithm_picker.cc | 7 ++-- .../core/kernels/batch_matmul_op_impl.h | 4 +-- tensorflow/core/kernels/conv_ops.cc | 16 ++++----- tensorflow/core/kernels/conv_ops_gpu.h | 6 ++-- tensorflow/core/kernels/cudnn_rnn_ops.cc | 15 ++++---- tensorflow/core/kernels/fft_ops.cc | 6 ++-- .../core/kernels/fused_batch_norm_op.cc | 10 +++--- tensorflow/stream_executor/cuda/cuda_blas.cc | 6 ++-- tensorflow/stream_executor/cuda/cuda_dnn.cc | 32 ++++++++--------- tensorflow/stream_executor/cuda/cuda_fft.cc | 3 +- .../stream_executor/cuda/redzone_allocator.cc | 34 +++++++++---------- .../stream_executor/cuda/redzone_allocator.h | 10 +++--- .../cuda/redzone_allocator_test.cc | 25 ++++++-------- .../stream_executor/scratch_allocator.cc | 11 +++--- .../stream_executor/scratch_allocator.h | 12 +++---- 19 files changed, 102 insertions(+), 129 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc index 99566c4aa11..46886d8df3e 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc @@ -143,10 +143,8 @@ StatusOr CheckRedzones(const se::cuda::RedzoneAllocator& allocator, XLA_SCOPED_LOGGING_TIMER_LEVEL("CudnnConvAlgorithmPicker checking redzones", 2); using RedzoneCheckStatus = se::cuda::RedzoneAllocator::RedzoneCheckStatus; - TF_ASSIGN_OR_RETURN(RedzoneCheckStatus redzone_check, - allocator.CheckRedzones(stream)); - + allocator.CheckRedzones()); if (redzone_check.ok()) { return true; } @@ -253,8 +251,6 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache( // Create a stream for us to do our work on. se::Stream stream{stream_exec_}; stream.Init(); - const auto device_ordinal = stream_exec_->device_ordinal(); - // allocator either points to this->allocator_ or, if that's null, to a // se::StreamExecutorMemoryAllocator for stream_exec_. se::DeviceMemoryAllocator* allocator; @@ -278,18 +274,18 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache( // Allocate space for the input, filter, and output of the convolution. se::cuda::RedzoneAllocator input_output_allocator( - device_ordinal, allocator, PtxOptsFromConfig(hlo_module_config)); + &stream, allocator, PtxOptsFromConfig(hlo_module_config)); std::vector operand_buffers; for (const auto* operand : instr->operands()) { TF_ASSIGN_OR_RETURN(auto buffer, input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(operand->shape()))); + ShapeUtil::ByteSizeOf(operand->shape()))); initialize_buffer(buffer); operand_buffers.push_back(buffer); } TF_ASSIGN_OR_RETURN(auto result_buffer, input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(result_shape))); + ShapeUtil::ByteSizeOf(result_shape))); initialize_buffer(result_buffer); TF_ASSIGN_OR_RETURN(auto backend_config, @@ -331,7 +327,7 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache( } se::cuda::RedzoneAllocator scratch_allocator( - device_ordinal, allocator, PtxOptsFromConfig(hlo_module_config)); + &stream, allocator, PtxOptsFromConfig(hlo_module_config)); se::dnn::ProfileResult profile_result; VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " << instr->ToString(); @@ -431,7 +427,7 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache( comparator.emplace(result_shape, hlo_module_config); TF_ASSIGN_OR_RETURN( reference_result_buffer, - input_output_allocator.AllocateBytes(&stream, result_buffer.size())); + input_output_allocator.AllocateBytes(result_buffer.size())); stream.ThenMemcpy(&reference_result_buffer, result_buffer, result_buffer.size()); first_algorithm = alg; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc index 5aa76ac0140..da5059e05c7 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc @@ -48,12 +48,10 @@ class ScratchBufAllocator : public se::ScratchAllocator { ~ScratchBufAllocator() override = default; - int64 GetMemoryLimitInBytes(se::Stream* /*stream*/) override { - return scratch_.size(); - } + int64 GetMemoryLimitInBytes() override { return scratch_.size(); } se::port::StatusOr> AllocateBytes( - se::Stream* stream, int64 byte_size) override { + int64 byte_size) override { if (allocated_) { return se::port::InternalError( "Can't allocate twice from a ScratchBufAllocator."); diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index da90ba989dc..991a463f2a0 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -32,20 +32,20 @@ FftScratchAllocator::FftScratchAllocator( int device_ordinal, se::DeviceMemoryAllocator* memory_allocator) : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} -int64 FftScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) { +int64 FftScratchAllocator::GetMemoryLimitInBytes() { constexpr int64 kFftScratchSize = 1LL << 32; // 4GB by default. return kFftScratchSize; } StatusOr> FftScratchAllocator::AllocateBytes( - se::Stream* stream, int64 byte_size) { + int64 byte_size) { CHECK_GE(byte_size, 0) << "byte_size must be positive."; - if (byte_size > GetMemoryLimitInBytes(stream)) { + if (byte_size > GetMemoryLimitInBytes()) { return se::port::Status( se::port::error::RESOURCE_EXHAUSTED, absl::StrFormat( "Allocating %d bytes exceeds the memory limit of %d bytes.", - byte_size, GetMemoryLimitInBytes(stream))); + byte_size, GetMemoryLimitInBytes())); } TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer, diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h index be77df1eb77..95186c7f219 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -40,12 +40,12 @@ class FftScratchAllocator : public se::ScratchAllocator { FftScratchAllocator(int device_ordinal, se::DeviceMemoryAllocator* memory_allocator); - int64 GetMemoryLimitInBytes(se::Stream* stream) override; + int64 GetMemoryLimitInBytes() override; int64 TotalAllocatedBytes() { return total_allocated_bytes_; } se::port::StatusOr> AllocateBytes( - se::Stream* stream, int64 byte_size) override; + int64 byte_size) override; private: const int device_ordinal_; diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc index 626bef76b98..24a2dced50c 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc @@ -110,7 +110,7 @@ static StatusOr> DoUncachedGemmAutotune( TF_ASSIGN_OR_RETURN( se::cuda::RedzoneAllocator::RedzoneCheckStatus rz_check_status, - allocator.CheckRedzones(stream)); + allocator.CheckRedzones()); if (!rz_check_status.ok()) { result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED); *result.mutable_failure()->mutable_msg() = @@ -244,8 +244,7 @@ static StatusOr RunOnInstruction(HloInstruction* instr, const HloModuleConfig& hlo_module_config = instr->GetModule()->config(); se::cuda::RedzoneAllocator input_output_allocator( - executor->device_ordinal(), allocator, - PtxOptsFromConfig(hlo_module_config)); + &stream, allocator, PtxOptsFromConfig(hlo_module_config)); BufferComparator comparator(instr->shape(), hlo_module_config); @@ -254,7 +253,7 @@ static StatusOr RunOnInstruction(HloInstruction* instr, [&](const HloInstruction* op) -> StatusOr { TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer, input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(op->shape()))); + ShapeUtil::ByteSizeOf(op->shape()))); InitializeFloatBuffer(&stream, op->shape().element_type(), &rng_state, buffer); return buffer; diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h index 84f7571d6a4..1e85dbcfc15 100644 --- a/tensorflow/core/kernels/batch_matmul_op_impl.h +++ b/tensorflow/core/kernels/batch_matmul_op_impl.h @@ -265,10 +265,10 @@ class BlasScratchAllocator : public se::ScratchAllocator { BlasScratchAllocator(OpKernelContext* context) : context_(context) {} - int64 GetMemoryLimitInBytes(Stream* stream) override { return -1; } + int64 GetMemoryLimitInBytes() override { return -1; } se::port::StatusOr AllocateBytes( - Stream* stream, int64 byte_size) override { + int64 byte_size) override { Tensor temporary_memory; Status allocation_status(context_->allocate_temp( diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index 4ea31861e7a..637098884a5 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -613,10 +613,9 @@ typedef AutoTuneSingleton rz_status = - rz_allocator.CheckRedzones(stream); + rz_allocator.CheckRedzones(); if (!rz_status.ok()) { static std::once_flag failure_logged; std::call_once(failure_logged, [&]() { @@ -1003,14 +1002,12 @@ void LaunchConv2DOp::operator()( se::TfAllocatorAdapter tf_allocator_adapter( stream->parent()->platform(), ctx->device()->GetAllocator({})); - se::cuda::RedzoneAllocator rz_allocator(stream->parent()->device_ordinal(), - &tf_allocator_adapter, + se::cuda::RedzoneAllocator rz_allocator(stream, &tf_allocator_adapter, se::cuda::PtxCompilationOptions()); - se::DeviceMemory output_tensor; if (!RedzoneCheckDisabled()) { - auto output_rz_or = rz_allocator.AllocateBytes(stream, output_ptr.size()); + auto output_rz_or = rz_allocator.AllocateBytes(output_ptr.size()); if (!output_rz_or.ok()) { static std::once_flag rz_allocation_failure_logged; std::call_once(rz_allocation_failure_logged, []() { @@ -1033,8 +1030,7 @@ void LaunchConv2DOp::operator()( // TODO(zhengxq): profile each algorithm multiple times to better // accuracy. se::cuda::RedzoneAllocator rz_scratch_allocator( - stream->parent()->device_ordinal(), &tf_allocator_adapter, - se::cuda::PtxCompilationOptions()); + stream, &tf_allocator_adapter, se::cuda::PtxCompilationOptions()); DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); se::ScratchAllocator* allocator_used = !RedzoneCheckDisabled() @@ -1061,8 +1057,8 @@ void LaunchConv2DOp::operator()( *result.mutable_run_time() = proto_utils::ToDurationProto( absl::Milliseconds(profile_result.elapsed_time_in_ms())); - CheckRedzones(rz_scratch_allocator, stream, &result); - CheckRedzones(rz_allocator, stream, &result); + CheckRedzones(rz_scratch_allocator, &result); + CheckRedzones(rz_allocator, &result); } } LogConvAutotuneResults(se::dnn::ConvolutionKind::FORWARD, diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h index 7906f74c616..c2c89ecdb9b 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.h +++ b/tensorflow/core/kernels/conv_ops_gpu.h @@ -50,11 +50,9 @@ class DnnScratchAllocator : public se::ScratchAllocator { virtual ~DnnScratchAllocator() {} DnnScratchAllocator(int64 memory_limit, OpKernelContext* context) : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {} - int64 GetMemoryLimitInBytes(se::Stream* stream) override { - return memory_limit_; - } + int64 GetMemoryLimitInBytes() override { return memory_limit_; } se::port::StatusOr> AllocateBytes( - se::Stream* stream, int64 byte_size) override { + int64 byte_size) override { Tensor temporary_memory; if (byte_size < 0) { return se::port::Status{se::port::error::INVALID_ARGUMENT, diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc index 09826f57ce5..d8e5d2abc88 100644 --- a/tensorflow/core/kernels/cudnn_rnn_ops.cc +++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc @@ -363,12 +363,11 @@ class CudnnRnnAllocatorInTemp : public ScratchAllocator { explicit CudnnRnnAllocatorInTemp(OpKernelContext* context) : context_(context) {} - int64 GetMemoryLimitInBytes(Stream* stream) override { + int64 GetMemoryLimitInBytes() override { return std::numeric_limits::max(); } - StatusOr> AllocateBytes(Stream* stream, - int64 byte_size) override { + StatusOr> AllocateBytes(int64 byte_size) override { Tensor temporary_memory; const DataType tf_data_type = ToTFDataType::value; int64 allocate_count = @@ -409,11 +408,10 @@ class CudnnRnnAllocatorInOutput : public ScratchAllocator { ~CudnnRnnAllocatorInOutput() override {} CudnnRnnAllocatorInOutput(OpKernelContext* context, int output_index) : context_(context), output_index_(output_index) {} - int64 GetMemoryLimitInBytes(Stream* stream) override { + int64 GetMemoryLimitInBytes() override { return std::numeric_limits::max(); } - StatusOr> AllocateBytes(Stream* stream, - int64 byte_size) override { + StatusOr> AllocateBytes(int64 byte_size) override { CHECK(total_byte_size_ == 0) << "Reserve space allocator can only be called once"; int64 allocate_count = @@ -449,12 +447,11 @@ class CudnnRNNPersistentSpaceAllocator : public ScratchAllocator { ~CudnnRNNPersistentSpaceAllocator() override {} - int64 GetMemoryLimitInBytes(Stream* stream) override { + int64 GetMemoryLimitInBytes() override { return std::numeric_limits::max(); } - StatusOr> AllocateBytes(Stream* stream, - int64 byte_size) override { + StatusOr> AllocateBytes(int64 byte_size) override { if (total_byte_size_ != 0) { return Status(error::FAILED_PRECONDITION, "Persistent space allocator can only be called once"); diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc index e0f326dcea3..fabd8e9cb36 100644 --- a/tensorflow/core/kernels/fft_ops.cc +++ b/tensorflow/core/kernels/fft_ops.cc @@ -315,11 +315,9 @@ class CufftScratchAllocator : public se::ScratchAllocator { ~CufftScratchAllocator() override {} CufftScratchAllocator(int64 memory_limit, OpKernelContext* context) : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {} - int64 GetMemoryLimitInBytes(se::Stream* stream) override { - return memory_limit_; - } + int64 GetMemoryLimitInBytes() override { return memory_limit_; } se::port::StatusOr> AllocateBytes( - se::Stream* stream, int64 byte_size) override { + int64 byte_size) override { Tensor temporary_memory; if (byte_size > memory_limit_) { return se::port::StatusOr>(); diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc index 70bd659be66..dd75b3718ae 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/fused_batch_norm_op.cc @@ -101,12 +101,11 @@ class CudnnBatchNormAllocatorInTemp : public ScratchAllocator { explicit CudnnBatchNormAllocatorInTemp(OpKernelContext* context) : context_(context) {} - int64 GetMemoryLimitInBytes(Stream* stream) override { + int64 GetMemoryLimitInBytes() override { return std::numeric_limits::max(); } - StatusOr> AllocateBytes(Stream* stream, - int64 byte_size) override { + StatusOr> AllocateBytes(int64 byte_size) override { Tensor temporary_memory; const DataType tf_data_type = DataTypeToEnum::v(); int64 allocate_count = @@ -155,12 +154,11 @@ class CudnnBatchNormAllocatorInOutput : public ScratchAllocator { CudnnBatchNormAllocatorInOutput(OpKernelContext* context, int output_index) : context_(context), output_index_(output_index) {} - int64 GetMemoryLimitInBytes(Stream* stream) override { + int64 GetMemoryLimitInBytes() override { return std::numeric_limits::max(); } - StatusOr> AllocateBytes(Stream* stream, - int64 byte_size) override { + StatusOr> AllocateBytes(int64 byte_size) override { output_allocated = true; DCHECK(total_byte_size_ == 0) << "Reserve space allocator can only be called once"; diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index 421b9b4ce42..742181d9249 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -2179,11 +2179,11 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal( // whether a scratch allocator was passed. if (scratch_allocator != nullptr) { SE_ASSIGN_OR_RETURN(DeviceMemory a_bytes, - scratch_allocator->AllocateBytes(stream, size)); + scratch_allocator->AllocateBytes(size)); SE_ASSIGN_OR_RETURN(DeviceMemory b_bytes, - scratch_allocator->AllocateBytes(stream, size)); + scratch_allocator->AllocateBytes(size)); SE_ASSIGN_OR_RETURN(DeviceMemory c_bytes, - scratch_allocator->AllocateBytes(stream, size)); + scratch_allocator->AllocateBytes(size)); a = DeviceMemory(a_bytes); b = DeviceMemory(b_bytes); c = DeviceMemory(c_bytes); diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 207b7201527..659214c4aab 100755 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -952,8 +952,8 @@ class CudnnDropoutDescriptor { size_t state_sizes_in_bytes = 0; RETURN_IF_CUDNN_ERROR( cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes)); - SE_ASSIGN_OR_RETURN(state_memory, state_allocator->AllocateBytes( - nullptr, state_sizes_in_bytes)); + SE_ASSIGN_OR_RETURN(state_memory, + state_allocator->AllocateBytes(state_sizes_in_bytes)); } RETURN_IF_CUDNN_ERROR(cudnnSetDropoutDescriptor( handle.get(), cudnn.handle(), dropout, state_memory.opaque(), @@ -1603,7 +1603,7 @@ port::StatusOr> CreateRnnWorkspace( if (workspace_size_in_bytes == 0) { return DeviceMemory(); } - return workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes); + return workspace_allocator->AllocateBytes(workspace_size_in_bytes); } #if CUDNN_VERSION >= 7402 @@ -1628,7 +1628,7 @@ port::StatusOr> CreateBatchNormForwardWorkspace( if (workspace_size_in_bytes == 0) { return DeviceMemory(); } - return workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes); + return workspace_allocator->AllocateBytes(workspace_size_in_bytes); } port::StatusOr> CreateBatchNormBackwardWorkspace( @@ -1652,7 +1652,7 @@ port::StatusOr> CreateBatchNormBackwardWorkspace( if (workspace_size_in_bytes == 0) { return DeviceMemory(); } - return workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes); + return workspace_allocator->AllocateBytes(workspace_size_in_bytes); } #endif @@ -1701,9 +1701,8 @@ port::Status CudnnSupport::DoRnnForwardImpl( /*sizeInBytes=*/&reserve_space_size_in_bytes)); if (reserve_space_size_in_bytes > 0) { - SE_ASSIGN_OR_RETURN(reserve_space, - reserve_space_allocator->AllocateBytes( - stream, reserve_space_size_in_bytes)); + SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes( + reserve_space_size_in_bytes)); } } @@ -2401,7 +2400,7 @@ port::StatusOr> AllocateCudnnConvolutionForwardWorkspace( "No scratch allocator provided"); } - return scratch_allocator->AllocateBytes(stream, size_in_bytes); + return scratch_allocator->AllocateBytes(size_in_bytes); } port::StatusOr> @@ -2446,7 +2445,7 @@ AllocateCudnnConvolutionBackwardDataWorkspace( "No scratch allocator provided"); } - return scratch_allocator->AllocateBytes(stream, size_in_bytes); + return scratch_allocator->AllocateBytes(size_in_bytes); } port::StatusOr> @@ -2491,7 +2490,7 @@ AllocateCudnnConvolutionBackwardFilterWorkspace( "No scratch allocator provided"); } - return scratch_allocator->AllocateBytes(stream, size_in_bytes); + return scratch_allocator->AllocateBytes(size_in_bytes); } static bool TensorOpMathAvailable(int cc_major) { @@ -2512,7 +2511,7 @@ port::StatusOr GetCudnnConvolutionForwardAlgorithm( bool specify_workspace_limit = scratch_allocator != nullptr; auto memory_limit_bytes = specify_workspace_limit - ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll) + ? std::max(scratch_allocator->GetMemoryLimitInBytes(), 0ll) : 0ll; SE_ASSIGN_OR_RETURN(cudnnConvolutionFwdAlgo_t algo, GetCudnnConvolutionForwardAlgo( @@ -2565,7 +2564,7 @@ port::StatusOr GetCudnnConvolutionBackwardDataAlgorithm( bool specify_workspace_limit = scratch_allocator != nullptr; auto memory_limit_bytes = specify_workspace_limit - ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll) + ? std::max(scratch_allocator->GetMemoryLimitInBytes(), 0ll) : 0ll; SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdDataAlgo_t algo, GetCudnnConvolutionBackwardDataAlgo( @@ -2617,7 +2616,7 @@ port::StatusOr GetCudnnConvolutionBackwardFilterAlgorithm( bool specify_workspace_limit = scratch_allocator != nullptr; auto memory_limit_bytes = specify_workspace_limit - ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll) + ? std::max(scratch_allocator->GetMemoryLimitInBytes(), 0ll) : 0ll; SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdFilterAlgo_t algo, GetCudnnConvolutionBackwardFilterAlgo( @@ -3470,9 +3469,8 @@ port::Status CudnnSupport::DoBatchNormalizationForwardImpl( /*activationDesc=*/activation_desc.handle(), /*xDesc=*/x_descriptor.handle(), /*sizeInBytes=*/&reserve_space_size_in_bytes)); - SE_ASSIGN_OR_RETURN(reserve_space, - reserve_space_allocator->AllocateBytes( - stream, reserve_space_size_in_bytes)); + SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes( + reserve_space_size_in_bytes)); } } #endif diff --git a/tensorflow/stream_executor/cuda/cuda_fft.cc b/tensorflow/stream_executor/cuda/cuda_fft.cc index 3bf2f5b9742..79047d989bb 100644 --- a/tensorflow/stream_executor/cuda/cuda_fft.cc +++ b/tensorflow/stream_executor/cuda/cuda_fft.cc @@ -244,8 +244,7 @@ port::Status CUDAFftPlan::Initialize(GpuExecutor *parent, Stream *stream, port::Status CUDAFftPlan::UpdateScratchAllocator( Stream *stream, ScratchAllocator *scratch_allocator) { if (scratch_size_bytes_ != 0) { - auto allocated = - scratch_allocator->AllocateBytes(stream, scratch_size_bytes_); + auto allocated = scratch_allocator->AllocateBytes(scratch_size_bytes_); if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) { LOG(ERROR) << "failed to allocate work area."; return allocated.status(); diff --git a/tensorflow/stream_executor/cuda/redzone_allocator.cc b/tensorflow/stream_executor/cuda/redzone_allocator.cc index 76ff86cbdd5..cebf5852403 100644 --- a/tensorflow/stream_executor/cuda/redzone_allocator.cc +++ b/tensorflow/stream_executor/cuda/redzone_allocator.cc @@ -45,10 +45,11 @@ constexpr int64 kRhsRedzoneAlign = 4; using RedzoneCheckStatus = RedzoneAllocator::RedzoneCheckStatus; RedzoneAllocator::RedzoneAllocator( - int device_ordinal, DeviceMemoryAllocator* memory_allocator, + Stream* stream, DeviceMemoryAllocator* memory_allocator, cuda::PtxCompilationOptions ptx_compilation_opts, uint64 redzone_size, uint8 redzone_pattern) - : device_ordinal_(device_ordinal), + : device_ordinal_(stream->parent()->device_ordinal()), + stream_(stream), redzone_size_(RoundUpToNearest( redzone_size, static_cast(tensorflow::Allocator::kAllocatorAlignment))), @@ -57,14 +58,14 @@ RedzoneAllocator::RedzoneAllocator( ptx_compilation_opts_(ptx_compilation_opts) {} port::StatusOr> RedzoneAllocator::AllocateBytes( - Stream* stream, int64 byte_size) { + int64 byte_size) { CHECK_GE(byte_size, 0) << "byte_size must be positive."; - if (byte_size > GetMemoryLimitInBytes(stream)) { + if (byte_size > GetMemoryLimitInBytes()) { return port::Status( port::error::RESOURCE_EXHAUSTED, absl::StrFormat( "Allocating %d bytes exceeds the memory limit of %d bytes.", - byte_size, GetMemoryLimitInBytes(stream))); + byte_size, GetMemoryLimitInBytes())); } int64 rhs_slop = RoundUpToNearest(byte_size, kRhsRedzoneAlign) - byte_size; @@ -78,10 +79,10 @@ port::StatusOr> RedzoneAllocator::AllocateBytes( static_assert(sizeof(uint8) == 1, "Unexpected size"); DeviceMemory allocated_buffer_memory(*allocated_buffer); - DeviceMemory lhs_redzone = stream->parent()->GetSubBuffer( + DeviceMemory lhs_redzone = stream_->parent()->GetSubBuffer( &allocated_buffer_memory, 0, redzone_size_); - DeviceMemory data_chunk = stream->parent()->GetSubBuffer( + DeviceMemory data_chunk = stream_->parent()->GetSubBuffer( &allocated_buffer_memory, redzone_size_, byte_size); // Split up the RHS redzone into two pieces: @@ -89,10 +90,10 @@ port::StatusOr> RedzoneAllocator::AllocateBytes( // - redzone_size_ bytes. // We do this because Stream::ThenMemset32 requires the buffer address and // size to be aligned to 4 bytes. - DeviceMemory rhs_redzone_slop = stream->parent()->GetSubBuffer( + DeviceMemory rhs_redzone_slop = stream_->parent()->GetSubBuffer( &allocated_buffer_memory, redzone_size_ + byte_size, rhs_slop); - DeviceMemory rhs_redzone_nonslop = stream->parent()->GetSubBuffer( + DeviceMemory rhs_redzone_nonslop = stream_->parent()->GetSubBuffer( &allocated_buffer_memory, redzone_size_ + byte_size + rhs_slop, redzone_size_); @@ -100,11 +101,11 @@ port::StatusOr> RedzoneAllocator::AllocateBytes( redzone_pattern_}; uint32 pattern32; std::memcpy(&pattern32, pattern_arr, sizeof(pattern32)); - stream->ThenMemset32(&lhs_redzone, pattern32, redzone_size_); + stream_->ThenMemset32(&lhs_redzone, pattern32, redzone_size_); if (rhs_slop != 0) { - stream->ThenMemcpy(&rhs_redzone_slop, &pattern32, rhs_slop); + stream_->ThenMemcpy(&rhs_redzone_slop, &pattern32, rhs_slop); } - stream->ThenMemset32(&rhs_redzone_nonslop, pattern32, redzone_size_); + stream_->ThenMemset32(&rhs_redzone_nonslop, pattern32, redzone_size_); allocated_buffers_.emplace_back(std::move(allocated_buffer), byte_size); return data_chunk; @@ -295,9 +296,8 @@ static port::StatusOr CheckRedzonesForBuffer( return RedzoneCheckStatus::OK(); } -port::StatusOr RedzoneAllocator::CheckRedzones( - Stream* stream) const { - StreamExecutor* executor = stream->parent(); +port::StatusOr RedzoneAllocator::CheckRedzones() const { + StreamExecutor* executor = stream_->parent(); absl::Span compiled_ptx = {}; port::StatusOr> compiled_ptx_or = @@ -316,7 +316,7 @@ port::StatusOr RedzoneAllocator::CheckRedzones( ScopedDeviceMemory out_param = executor->AllocateOwnedScalar(); - stream->ThenMemZero(out_param.ptr(), sizeof(uint64)); + stream_->ThenMemZero(out_param.ptr(), sizeof(uint64)); TF_ASSIGN_OR_RETURN( std::unique_ptr comparison_kernel, @@ -327,7 +327,7 @@ port::StatusOr RedzoneAllocator::CheckRedzones( for (const auto& buf_and_size : allocated_buffers_) { TF_ASSIGN_OR_RETURN( RedzoneCheckStatus redzone_status, - CheckRedzonesForBuffer(stream, *buf_and_size.first, out_param.cref(), + CheckRedzonesForBuffer(stream_, *buf_and_size.first, out_param.cref(), *comparison_kernel, buf_and_size.second, redzone_size_, redzone_pattern_)); if (!redzone_status.ok()) { diff --git a/tensorflow/stream_executor/cuda/redzone_allocator.h b/tensorflow/stream_executor/cuda/redzone_allocator.h index 42ddd99b7ce..c78b54e0c5f 100644 --- a/tensorflow/stream_executor/cuda/redzone_allocator.h +++ b/tensorflow/stream_executor/cuda/redzone_allocator.h @@ -39,21 +39,20 @@ namespace cuda { // memory for cudnn convolutions. class RedzoneAllocator : public ScratchAllocator { public: - RedzoneAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator, + RedzoneAllocator(Stream* stream, DeviceMemoryAllocator* memory_allocator, cuda::PtxCompilationOptions ptx_compilation_opts, uint64 redzone_size = 1 << 23, // 8MiB per side, 16MiB total uint8 redzone_pattern = -1); // Redzones don't count towards the memory limit. - int64 GetMemoryLimitInBytes(Stream* stream) override { + int64 GetMemoryLimitInBytes() override { return 1LL << 32; // 4GB. TODO(jlebar): Tune this? } int64 TotalAllocatedBytesExcludingRedzones() const { return allocated_bytes_excluding_redzones_; } - port::StatusOr> AllocateBytes(Stream* stream, - int64 byte_size) override; + port::StatusOr> AllocateBytes(int64 byte_size) override; // Non-empty redzone check status implies that there was a write into a // redzone, with a string communicating the location of the write. @@ -92,10 +91,11 @@ class RedzoneAllocator : public ScratchAllocator { // - RedzoneCheckStatus with a non-empty error message iff a write into a // redzone has been detected. // - A stream error, if loading or launching the kernel has failed. - port::StatusOr CheckRedzones(Stream* stream) const; + port::StatusOr CheckRedzones() const; private: const int device_ordinal_; + Stream* stream_; // Redzone size on *one side* of allocation. // diff --git a/tensorflow/stream_executor/cuda/redzone_allocator_test.cc b/tensorflow/stream_executor/cuda/redzone_allocator_test.cc index 23fee5164e5..9f6d1bd6046 100644 --- a/tensorflow/stream_executor/cuda/redzone_allocator_test.cc +++ b/tensorflow/stream_executor/cuda/redzone_allocator_test.cc @@ -55,15 +55,14 @@ TEST(RedzoneAllocatorTest, WriteToRedzone) { StreamExecutor* stream_exec = platform->ExecutorForDevice(0).ValueOrDie(); cuda::PtxCompilationOptions opts; StreamExecutorMemoryAllocator se_allocator(platform, {stream_exec}); - RedzoneAllocator allocator(/*device_ordinal=*/0, &se_allocator, opts, - kRedzoneSize, kRedzonePattern); Stream stream(stream_exec); stream.Init(); + RedzoneAllocator allocator(&stream, &se_allocator, opts, kRedzoneSize, + kRedzonePattern); TF_ASSERT_OK_AND_ASSIGN(DeviceMemory buf, - allocator.AllocateBytes(&stream, - /*byte_size=*/kAllocSize)); - EXPECT_REDZONE_OK(allocator.CheckRedzones(&stream)); + allocator.AllocateBytes(/*byte_size=*/kAllocSize)); + EXPECT_REDZONE_OK(allocator.CheckRedzones()); char* buf_addr = reinterpret_cast(buf.opaque()); DeviceMemoryBase lhs_redzone(buf_addr - kRedzoneSize, kRedzoneSize); @@ -100,15 +99,13 @@ TEST(RedzoneAllocatorTest, WriteToRedzone) { DeviceMemoryBase redzone_at_offset( reinterpret_cast(redzone.opaque()) + offset, 1); char old_redzone_value = 0; - { - EXPECT_REDZONE_OK(allocator.CheckRedzones(&stream)); - } + { EXPECT_REDZONE_OK(allocator.CheckRedzones()); } stream.ThenMemcpy(&old_redzone_value, redzone_at_offset, 1) .ThenMemZero(&redzone_at_offset, 1); - EXPECT_REDZONE_VIOLATION(allocator.CheckRedzones(&stream)); + EXPECT_REDZONE_VIOLATION(allocator.CheckRedzones()); // Checking reinitializes the redzone. - EXPECT_REDZONE_OK(allocator.CheckRedzones(&stream)); + EXPECT_REDZONE_OK(allocator.CheckRedzones()); }; modify_redzone(lhs_redzone, /*offset=*/0, "lhs"); @@ -130,12 +127,12 @@ TEST(RedzoneAllocatorTest, VeryLargeRedzone) { StreamExecutor* stream_exec = platform->ExecutorForDevice(0).ValueOrDie(); cuda::PtxCompilationOptions opts; StreamExecutorMemoryAllocator se_allocator(platform, {stream_exec}); - RedzoneAllocator allocator(/*device_ordinal=*/0, &se_allocator, opts, - kRedzoneSize, /*redzone_pattern=*/-1); Stream stream(stream_exec); stream.Init(); - (void)allocator.AllocateBytes(&stream, /*byte_size=*/1); - EXPECT_REDZONE_OK(allocator.CheckRedzones(&stream)); + RedzoneAllocator allocator(&stream, &se_allocator, opts, kRedzoneSize, + /*redzone_pattern=*/-1); + (void)allocator.AllocateBytes(/*byte_size=*/1); + EXPECT_REDZONE_OK(allocator.CheckRedzones()); } } // namespace diff --git a/tensorflow/stream_executor/scratch_allocator.cc b/tensorflow/stream_executor/scratch_allocator.cc index 8fc4c4c509c..520ee8a4208 100644 --- a/tensorflow/stream_executor/scratch_allocator.cc +++ b/tensorflow/stream_executor/scratch_allocator.cc @@ -22,18 +22,17 @@ namespace stream_executor { ScratchAllocator::~ScratchAllocator() {} -OneTimeScratchAllocator::OneTimeScratchAllocator() {} +OneTimeScratchAllocator::OneTimeScratchAllocator(Stream* stream) + : stream_(stream) {} OneTimeScratchAllocator::~OneTimeScratchAllocator() {} -int64 OneTimeScratchAllocator::GetMemoryLimitInBytes(Stream* stream) { - return -1; -} +int64 OneTimeScratchAllocator::GetMemoryLimitInBytes() { return -1; } port::StatusOr> OneTimeScratchAllocator::AllocateBytes( - Stream* stream, int64 byte_size) { + int64 byte_size) { CHECK(temporary_ == nullptr); SE_ASSIGN_OR_RETURN(temporary_, - stream->AllocateTemporaryArray(byte_size)); + stream_->AllocateTemporaryArray(byte_size)); return temporary_->device_memory(); } diff --git a/tensorflow/stream_executor/scratch_allocator.h b/tensorflow/stream_executor/scratch_allocator.h index 2aed2c44373..31278937fe4 100644 --- a/tensorflow/stream_executor/scratch_allocator.h +++ b/tensorflow/stream_executor/scratch_allocator.h @@ -45,14 +45,14 @@ class ScratchAllocator { // bytes. This information may be used to help select an algorithm. // // Returns values < 0 to indicate that there is no recommended limit. - virtual int64 GetMemoryLimitInBytes(Stream* stream) = 0; + virtual int64 GetMemoryLimitInBytes() = 0; // Returns an allocation on byte_size bytes for use in an operation on stream. // // This is a temporary allocation, and the caller is responsible for // deallocating at some known-safe point. See the class comment above. virtual port::StatusOr> AllocateBytes( - Stream* stream, int64 byte_size) = 0; + int64 byte_size) = 0; }; // Allocates a single temporary memory allocation -- this memory is deallocated @@ -64,14 +64,14 @@ class ScratchAllocator { // thread will request the scratch allocation). class OneTimeScratchAllocator : public ScratchAllocator { public: - OneTimeScratchAllocator(); + explicit OneTimeScratchAllocator(Stream* stream); ~OneTimeScratchAllocator() override; - int64 GetMemoryLimitInBytes(Stream* stream) override; - port::StatusOr> AllocateBytes(Stream* stream, - int64 byte_size) override; + int64 GetMemoryLimitInBytes() override; + port::StatusOr> AllocateBytes(int64 byte_size) override; private: std::unique_ptr> temporary_; + Stream* stream_; SE_DISALLOW_COPY_AND_ASSIGN(OneTimeScratchAllocator); };