Use DnnScratchAllocator
This commit is contained in:
parent
0b9feecc74
commit
4a89f04615
@ -367,16 +367,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cudnn_scratch_allocator",
|
||||
srcs = ["util/cudnn_scratch_allocator.cc"],
|
||||
hdrs = ["util/cudnn_scratch_allocator.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/stream_executor:scratch_allocator",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "util_port_hdrs",
|
||||
srcs = [
|
||||
|
@ -2297,8 +2297,8 @@ tf_kernel_library(
|
||||
"//tensorflow/core/util/ctc:ctc_beam_search_lib",
|
||||
"//tensorflow/core/util/ctc:ctc_loss_calculator_lib",
|
||||
] + if_cuda([
|
||||
"//tensorflow/core:stream_executor",
|
||||
"//tensorflow/core:cudnn_scratch_allocator",
|
||||
":gpu_utils",
|
||||
":conv_ops_gpu_hdrs",
|
||||
]),
|
||||
)
|
||||
|
||||
|
@ -30,9 +30,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/util/sparse/sparse_tensor.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
#include "tensorflow/core/kernels/conv_ops_gpu.h"
|
||||
#include "tensorflow/core/util/stream_executor_util.h"
|
||||
#include "tensorflow/core/util/cudnn_scratch_allocator.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
@ -323,21 +323,19 @@ class CTCLossOpGPU : public OpKernel {
|
||||
OP_REQUIRES_OK(ctx, grads_desc_s.status());
|
||||
grads_desc = grads_desc_s.ConsumeValueOrDie();
|
||||
|
||||
absl::Span<const int32> labels_data;
|
||||
absl::Span<const int32> labels_lengths_data;
|
||||
absl::Span<const int32> input_lengths_data;
|
||||
labels_data = absl::Span<const int32>(
|
||||
labels_values->flat<int32>().data(), num_indices);
|
||||
labels_lengths_data = absl::Span<const int32>(
|
||||
labels_lengths.data(), batch_size);
|
||||
input_lengths_data = absl::Span<const int32>(
|
||||
seq_len->flat<int32>().data(), batch_size);
|
||||
absl::Span<const int32> labels_data(labels_values->flat<int32>().data(),
|
||||
num_indices);
|
||||
absl::Span<const int32> labels_lengths_data(labels_lengths.data(),
|
||||
batch_size);
|
||||
absl::Span<const int32> input_lengths_data(seq_len->flat<int32>().data(),
|
||||
batch_size);
|
||||
|
||||
auto probs_data = StreamExecutorUtil::AsDeviceMemory<float>(*inputs);
|
||||
auto costs_data = StreamExecutorUtil::AsDeviceMemory<float>(*loss);
|
||||
auto grads_data = StreamExecutorUtil::AsDeviceMemory<float>(*gradient);
|
||||
|
||||
CudnnAllocatorInTemp workspace_allocator(ctx);
|
||||
// Set the memory limitation to 4GB for workspace memory.
|
||||
DnnScratchAllocator workspace_allocator(1LL << 32, ctx);
|
||||
|
||||
Stream* stream = ctx->op_device_context()->stream();
|
||||
bool cudnn_launch_status =
|
||||
|
@ -42,7 +42,7 @@ class CudnnAllocatorInTemp : public ScratchAllocator {
|
||||
OpKernelContext* context_; // not owned
|
||||
std::vector<Tensor> allocated_tensors_;
|
||||
|
||||
SE_DISALLOW_COPY_AND_ASSIGN(CudnnAllocatorInTemp);
|
||||
SE_DISALLOW_COPY_AND_ASSIGN(CudnnAllocatorInTemp);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user