Remove duplicated code in conv and fft kernels
This commit is contained in:
parent
0d172940c1
commit
f7d29c94df
@ -619,19 +619,7 @@ template struct LaunchConv2DOp<CPUDevice, double>;
|
||||
|
||||
// Returns a DNN workspace limit in bytes.
//
// envvar_in_mb:           name of an environment variable whose value is read
//                         as a size in megabytes.
// default_value_in_bytes: value returned when the variable is unset, empty, or
//                         fails to parse as a 64-bit integer.
//
// NOTE(review): this text is a rendered diff without +/- markers, so the body
// below appears to contain both the pre-change inline logic and the
// post-change one-line delegation to GetWorkspaceLimit -- confirm against the
// real file before relying on either.
int64 GetDnnWorkspaceLimit(const string& envvar_in_mb,
|
||||
int64 default_value_in_bytes) {
|
||||
// Look up the raw string value; getenv yields nullptr when the variable is
// not set.
const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
|
||||
// Only attempt to parse when the variable is both set and non-empty.
if (workspace_limit_in_mb_str != nullptr &&
|
||||
strcmp(workspace_limit_in_mb_str, "") != 0) {
|
||||
int64 scratch_limit_in_mb = -1;
|
||||
if (strings::safe_strto64(workspace_limit_in_mb_str,
|
||||
&scratch_limit_in_mb)) {
|
||||
// Convert megabytes to bytes (1 << 20 bytes per MB).
return scratch_limit_in_mb * (1 << 20);
|
||||
} else {
|
||||
// Parse failure is not fatal: warn and fall through to the default.
LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
|
||||
<< workspace_limit_in_mb_str;
|
||||
}
|
||||
}
|
||||
return default_value_in_bytes;
|
||||
// NOTE(review): as literally written this statement follows an unconditional
// return and cannot execute; presumably it is the added line of the diff that
// replaces everything above it in the body -- TODO confirm.
return GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes);
|
||||
}
|
||||
|
||||
// A dummy type to group forward convolution autotune results together.
|
||||
|
@ -48,52 +48,7 @@ int64 GetDnnWorkspaceLimit(const string& envvar_in_mb,
|
||||
// A class to provide a scratch-space allocator for the Stream-Executor Cudnn
|
||||
// callback. TensorFlow is responsible for releasing the temporary buffers after
|
||||
// the kernel finishes.
|
||||
//
// Each successful AllocateBytes() call allocates a DT_UINT8 temporary tensor
// through the OpKernelContext and keeps a copy of that Tensor in
// allocated_tensors_ for as long as this allocator object exists.
class DnnScratchAllocator : public se::ScratchAllocator {
|
||||
public:
|
||||
virtual ~DnnScratchAllocator() {}
|
||||
// memory_limit: upper bound, in bytes, accepted for a single AllocateBytes()
//               request.
// context:      kernel context used to allocate the temporary tensors; stored
//               as a raw pointer (presumably must outlive this allocator --
//               TODO confirm with callers).
DnnScratchAllocator(int64 memory_limit, OpKernelContext* context)
|
||||
: memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
|
||||
// Returns the limit passed to the constructor.
int64 GetMemoryLimitInBytes() override { return memory_limit_; }
|
||||
// Allocates `byte_size` bytes of scratch memory.
//
// Returns INVALID_ARGUMENT when byte_size is negative, UNAVAILABLE when
// byte_size exceeds memory_limit_ or when the temporary tensor cannot be
// allocated, and otherwise a DeviceMemory<uint8> wrapping the new tensor's
// flat buffer.
se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
|
||||
int64 byte_size) override {
|
||||
Tensor temporary_memory;
|
||||
if (byte_size < 0) {
|
||||
return se::port::Status{se::port::error::INVALID_ARGUMENT,
|
||||
"Requested negative byte size!"};
|
||||
}
|
||||
// The limit is checked against this single request only; total_byte_size_
// is not part of the comparison.
if (byte_size > memory_limit_) {
|
||||
return se::port::Status{se::port::error::UNAVAILABLE,
|
||||
absl::StrCat("Requested memory size (", byte_size,
|
||||
") exceeds the max memory limit (",
|
||||
memory_limit_, ").")};
|
||||
}
|
||||
AllocationAttributes allocation_attr;
|
||||
// Disable retry so a failed allocation comes back as a status that is
// handled just below.
allocation_attr.retry_on_failure = false;
|
||||
Status allocation_status(context_->allocate_temp(
|
||||
DT_UINT8, TensorShape({byte_size}), &temporary_memory,
|
||||
AllocatorAttributes(), allocation_attr));
|
||||
if (!allocation_status.ok()) {
|
||||
// The underlying allocation status is not forwarded; a generic UNAVAILABLE
// status carrying the requested size is returned instead.
return se::port::Status{
|
||||
se::port::error::UNAVAILABLE,
|
||||
absl::StrCat("Failed to allocate the requested memory size (",
|
||||
byte_size, ").")};
|
||||
}
|
||||
// Hold the reference of the allocated tensors until the end of the
|
||||
// allocator.
|
||||
allocated_tensors_.push_back(temporary_memory);
|
||||
total_byte_size_ += byte_size;
|
||||
return se::port::StatusOr<se::DeviceMemory<uint8>>(
|
||||
AsDeviceMemory(temporary_memory.flat<uint8>().data(),
|
||||
temporary_memory.flat<uint8>().size()));
|
||||
}
|
||||
// Returns the sum of byte_size over all successful AllocateBytes() calls.
int64 TotalByteSize() { return total_byte_size_; }
|
||||
|
||||
private:
|
||||
// Per-request size cap in bytes.
int64 memory_limit_;
|
||||
// Running total of bytes handed out so far.
int64 total_byte_size_;
|
||||
// Context used for allocate_temp.
OpKernelContext* context_;
|
||||
// Tensors backing the returned scratch buffers.
std::vector<Tensor> allocated_tensors_;
|
||||
};
|
||||
// Makes DnnScratchAllocator an alias for GpuScratchAllocator.
// NOTE(review): GpuScratchAllocator is not defined in the visible text;
// presumably it comes from gpu_utils.h -- TODO confirm.
using DnnScratchAllocator = GpuScratchAllocator;
|
||||
|
||||
// Encapsulate all the shape information that is used in both forward and
|
||||
// backward conv operations.
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
|
||||
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
|
||||
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
|
||||
#include "tensorflow/core/kernels/gpu_utils.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
@ -400,20 +401,7 @@ class CufftScratchAllocator : public se::ScratchAllocator {
|
||||
|
||||
// Returns a cuFFT workspace limit in bytes.
//
// envvar_in_mb:           name of an environment variable whose value is read
//                         as a size in megabytes.
// default_value_in_bytes: value returned when the variable is unset, empty, or
//                         cannot be read as a 64-bit integer.
//
// NOTE(review): this text is a rendered diff without +/- markers, so the body
// below appears to contain both the pre-change inline logic and the
// post-change one-line delegation to GetWorkspaceLimit -- confirm against the
// real file before relying on either.
int64 GetCufftWorkspaceLimit(const string& envvar_in_mb,
|
||||
int64 default_value_in_bytes) {
|
||||
// Raw lookup is used both for the set/non-empty check and for the warning
// message below.
const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
|
||||
if (workspace_limit_in_mb_str != nullptr &&
|
||||
strcmp(workspace_limit_in_mb_str, "") != 0) {
|
||||
int64 scratch_limit_in_mb = -1;
|
||||
// NOTE(review): a value in bytes is passed as the helper's default while the
// result is treated as megabytes below; presumably harmless because the
// variable was just checked to be set and non-empty -- TODO confirm.
Status status = ReadInt64FromEnvVar(envvar_in_mb, default_value_in_bytes,
|
||||
&scratch_limit_in_mb);
|
||||
if (!status.ok()) {
|
||||
// Read failure is not fatal: warn and fall through to the default.
LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
|
||||
<< workspace_limit_in_mb_str;
|
||||
} else {
|
||||
// Convert megabytes to bytes (1 << 20 bytes per MB).
return scratch_limit_in_mb * (1 << 20);
|
||||
}
|
||||
}
|
||||
return default_value_in_bytes;
|
||||
// NOTE(review): as literally written this statement follows an unconditional
// return and cannot execute; presumably it is the added line of the diff that
// replaces everything above it in the body -- TODO confirm.
return GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes);
|
||||
}
|
||||
|
||||
class FFTGPUBase : public FFTBase {
|
||||
|
Loading…
Reference in New Issue
Block a user