Remove duplicated code in conv and fft kernels

This commit is contained in:
Ben Barsdell 2020-07-06 21:30:27 +10:00
parent 0d172940c1
commit f7d29c94df
3 changed files with 4 additions and 73 deletions

View File

@@ -619,19 +619,7 @@ template struct LaunchConv2DOp<CPUDevice, double>;
int64 GetDnnWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes) {
const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
if (workspace_limit_in_mb_str != nullptr &&
strcmp(workspace_limit_in_mb_str, "") != 0) {
int64 scratch_limit_in_mb = -1;
if (strings::safe_strto64(workspace_limit_in_mb_str,
&scratch_limit_in_mb)) {
return scratch_limit_in_mb * (1 << 20);
} else {
LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
<< workspace_limit_in_mb_str;
}
}
return default_value_in_bytes;
return GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes);
}
// A dummy type to group forward convolution autotune results together.

View File

@@ -48,52 +48,7 @@ int64 GetDnnWorkspaceLimit(const string& envvar_in_mb,
// A class to provide scratch-space allocator for Stream-Executor Cudnn
// callback. TensorFlow is responsible for releasing the temporary buffers after
// the kernel finishes.
class DnnScratchAllocator : public se::ScratchAllocator {
public:
virtual ~DnnScratchAllocator() {}
DnnScratchAllocator(int64 memory_limit, OpKernelContext* context)
: memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
int64 GetMemoryLimitInBytes() override { return memory_limit_; }
se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
int64 byte_size) override {
Tensor temporary_memory;
if (byte_size < 0) {
return se::port::Status{se::port::error::INVALID_ARGUMENT,
"Requested negative byte size!"};
}
if (byte_size > memory_limit_) {
return se::port::Status{se::port::error::UNAVAILABLE,
absl::StrCat("Requested memory size (", byte_size,
") exceeds the max memory limit (",
memory_limit_, ").")};
}
AllocationAttributes allocation_attr;
allocation_attr.retry_on_failure = false;
Status allocation_status(context_->allocate_temp(
DT_UINT8, TensorShape({byte_size}), &temporary_memory,
AllocatorAttributes(), allocation_attr));
if (!allocation_status.ok()) {
return se::port::Status{
se::port::error::UNAVAILABLE,
absl::StrCat("Failed to allocate the requested memory size (",
byte_size, ").")};
}
// Hold the reference of the allocated tensors until the end of the
// allocator.
allocated_tensors_.push_back(temporary_memory);
total_byte_size_ += byte_size;
return se::port::StatusOr<se::DeviceMemory<uint8>>(
AsDeviceMemory(temporary_memory.flat<uint8>().data(),
temporary_memory.flat<uint8>().size()));
}
int64 TotalByteSize() { return total_byte_size_; }
private:
int64 memory_limit_;
int64 total_byte_size_;
OpKernelContext* context_;
std::vector<Tensor> allocated_tensors_;
};
using DnnScratchAllocator = GpuScratchAllocator;
// Encapsulate all the shape information that is used in both forward and
// backward conv operations.

View File

@@ -31,6 +31,7 @@ limitations under the License.
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -400,20 +401,7 @@ class CufftScratchAllocator : public se::ScratchAllocator {
int64 GetCufftWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes) {
const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
if (workspace_limit_in_mb_str != nullptr &&
strcmp(workspace_limit_in_mb_str, "") != 0) {
int64 scratch_limit_in_mb = -1;
Status status = ReadInt64FromEnvVar(envvar_in_mb, default_value_in_bytes,
&scratch_limit_in_mb);
if (!status.ok()) {
LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
<< workspace_limit_in_mb_str;
} else {
return scratch_limit_in_mb * (1 << 20);
}
}
return default_value_in_bytes;
return GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes);
}
class FFTGPUBase : public FFTBase {