Adding 3D pooling using the latest MIOpen API
This commit is contained in:
parent 291125835e
commit 8681b1bf54
@@ -259,9 +259,6 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) {
   RunTest(x, x_init_value, y, y_shape);
 }
 
-// TODO(rocm):
-// Re-enable this test once 3D pooling is supported on ROCm platform
-#ifndef TENSORFLOW_USE_ROCM
 TEST_F(NNGradTest, MaxPool3DGradHelper) {
   TensorShape x_shape({1, 3, 3, 3, 1});
   TensorShape y_shape({1, 1, 1, 1, 1});
@@ -274,7 +271,6 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) {
   SetRandomValuesForMaxPooling<float>(&x_init_value);
   RunTest(x, x_init_value, y, y_shape);
 }
-#endif
 
 TEST_F(NNGradTest, AvgPoolGradHelper) {
   TensorShape x_shape({1, 2, 2, 1});
@@ -287,9 +283,6 @@ TEST_F(NNGradTest, AvgPoolGradHelper) {
   RunTest(x, x_shape, y, y_shape);
 }
 
-// TODO(rocm):
-// Re-enable this test once 3D pooling is supported on ROCm platform
-#ifndef TENSORFLOW_USE_ROCM
 TEST_F(NNGradTest, AvgPool3DGradHelper) {
   TensorShape x_shape({1, 3, 3, 3, 1});
   TensorShape y_shape({1, 1, 1, 1, 1});
@@ -300,7 +293,6 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) {
   auto y = AvgPool3D(scope_, x, ksize, strides, "SAME");
   RunTest(x, x_shape, y, y_shape);
 }
-#endif
 
 TEST_F(NNGradTest, LRN) {
   TensorShape x_shape({1, 1, 2, 1});
@@ -98,10 +98,25 @@ void DnnPooling3dOp<T>::Compute(OpKernelContext* context,
   auto* stream = context->op_device_context()->stream();
   OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
 
+#if TENSORFLOW_USE_ROCM
+  static int64 PoolingScratchSize = GetDnnWorkspaceLimit(
+      // default value is in bytes despite the name of the environment variable
+      "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
+  );
+
+  DnnScratchAllocator scratch_allocator(PoolingScratchSize, context);
+  bool status =
+      stream
+          ->ThenPoolForward(pooling_desc, input_desc, input_data, output_desc,
+                            &output_data, &scratch_allocator)
+          .ok();
+#else
   bool status = stream
                     ->ThenPoolForward(pooling_desc, input_desc, input_data,
                                       output_desc, &output_data)
                     .ok();
+#endif
 
   OP_REQUIRES(context, status,
               errors::Internal("dnn PoolForward launch failed"));
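Note: a minimal sketch of the semantics this hunk relies on for `GetDnnWorkspaceLimit` (the sketch below is a hypothetical stand-in, not the TensorFlow source): the environment variable is interpreted in megabytes, while the hard-coded fallback is already in bytes, which is what the in-diff comment is pointing out.

```cpp
#include <cstdint>
#include <cstdlib>

// Hypothetical stand-in for GetDnnWorkspaceLimit: the env var value is
// read in MB, but the compiled-in default is already a byte count.
int64_t DnnWorkspaceLimitSketch(const char* envvar_name,
                                int64_t default_value_in_bytes) {
  const char* value = std::getenv(envvar_name);
  if (value == nullptr) return default_value_in_bytes;  // e.g. 1LL << 32
  return std::atoll(value) * (1LL << 20);               // MB -> bytes
}
```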
@@ -225,12 +240,28 @@ void DnnPooling3dGradOp<T>::Compute(
   auto* stream = context->op_device_context()->stream();
   OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
 
+#if TENSORFLOW_USE_ROCM
+  static int64 PoolingScratchSize = GetDnnWorkspaceLimit(
+      // default value is in bytes despite the name of the environment variable
+      "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
+  );
+
+  DnnScratchAllocator scratch_allocator(PoolingScratchSize, context);
+  bool status = stream
+                    ->ThenPoolBackward(pooling_desc, orig_input_desc,
+                                       orig_input_data, orig_output_desc,
+                                       orig_output_data, output_backprop_data,
+                                       &input_backprop_data, &scratch_allocator)
+                    .ok();
+#else
   bool status =
       stream
           ->ThenPoolBackward(pooling_desc, orig_input_desc, orig_input_data,
                              orig_output_desc, orig_output_data,
                              output_backprop_data, &input_backprop_data)
          .ok();
+#endif
 
   OP_REQUIRES(context, status,
               errors::Internal("dnn PoolBackward launch failed"));
@@ -1455,9 +1455,6 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testMaxPooling3DGradient(self):
 
-    if test.is_built_with_rocm():
-      self.skipTest('Pooling with 3D tensors is not supported in ROCm')
-
     def forward(a):
       r = max_pooling3d(a, pool_size=pool_size, strides=strides, padding='SAME')
       return r
@@ -2995,7 +2995,6 @@ cuda_py_test(
     name = "pooling_ops_3d_test",
     size = "medium",
     srcs = ["pooling_ops_3d_test.py"],
-    tags = ["no_rocm"],
     deps = [
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
@@ -219,8 +219,6 @@ class PoolingTest(test.TestCase):
         strides=strides)
 
   def testPool3D(self):
-    if test.is_built_with_rocm():
-      self.skipTest("Pooling with 3D tensors is not supported in ROCm")
     with self.session(use_gpu=test.is_gpu_available()):
       for padding in ["SAME", "VALID"]:
         for pooling_type in ["MAX", "AVG"]:
@@ -363,8 +361,6 @@ class PoolingTest(test.TestCase):
 
   @test_util.run_deprecated_v1
   def testGradient3D(self):
-    if test.is_built_with_rocm():
-      self.skipTest("Pooling with 3D tensors is not supported in ROCm")
     with self.session(use_gpu=test.is_gpu_available()):
       for padding in ["SAME", "VALID"]:
         for pooling_type in ["AVG", "MAX"]:
@@ -488,8 +488,6 @@ class NNTest(PForTestCase):
     self._test_loop_fn(loop_fn, 3)
 
   def test_max_pool3d(self):
-    if test.is_built_with_rocm():
-      self.skipTest("Pooling with 3D tensors is not supported in ROCm")
     with backprop.GradientTape(persistent=True) as g:
       x = random_ops.random_uniform([3, 3, 2, 12, 12, 3])
       g.watch(x)
@@ -263,7 +263,8 @@ namespace wrap {
   __macro(miopenFindConvolutionForwardAlgorithm)      \
   __macro(miopenCreateTensorDescriptor)               \
   __macro(miopenDestroyTensorDescriptor)              \
-  __macro(miopenSet2dPoolingDescriptor)               \
+  __macro(miopenSetNdPoolingDescriptor)               \
+  __macro(miopenSetPoolingIndexType)                  \
   __macro(miopenSetLRNDescriptor)                     \
   __macro(miopenLRNGetWorkSpaceSize)                  \
   __macro(miopenCreateConvolutionDescriptor)          \
@@ -290,7 +291,7 @@ namespace wrap {
   __macro(miopenSetTensorDescriptor)                  \
   __macro(miopenGetTensorDescriptorSize)              \
   __macro(miopenPoolingForward)                       \
-  __macro(miopenPoolingGetWorkSpaceSize)              \
+  __macro(miopenPoolingGetWorkSpaceSizeV2)            \
   __macro(miopenPoolingBackward)                      \
   __macro(miopenLRNForward)                           \
   __macro(miopenLRNBackward)                          \
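Note: these `__macro` lists feed a wrapper generator in rocm_dnn.cc, so adding `miopenSetNdPoolingDescriptor`, `miopenSetPoolingIndexType`, and `miopenPoolingGetWorkSpaceSizeV2` here is what makes the `wrap::` calls below resolve. A hedged sketch of the idea, assuming a dlopen-based loader (the real STREAM_EXECUTOR_MIOPEN_WRAP machinery is more involved):

```cpp
#include <dlfcn.h>

// Sketch only: resolve a MIOpen entry point by name at first use. The real
// wrapper also caches per-symbol state and routes through the GPU executor.
template <typename FuncPtrT>
FuncPtrT LoadMiopenSymbol(const char* symbol_name) {
  static void* dso = dlopen("libMIOpen.so", RTLD_LAZY);
  return dso ? reinterpret_cast<FuncPtrT>(dlsym(dso, symbol_name)) : nullptr;
}
```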
@@ -605,6 +606,11 @@ MIOpenSupport::MIOpenSupport(GpuExecutor* parent) : parent_(parent) {
   // switch to Find Mode if env var TF_ROCM_USE_IMMEDIATE_MODE is set
   tensorflow::ReadBoolFromEnvVar("TF_ROCM_USE_IMMEDIATE_MODE", false,
                                  &use_immediate_mode_);
+
+  bool enable_pooling_cache = false;
+  tensorflow::ReadBoolFromEnvVar("TF_ROCM_BW_POOL_CACHE", false,
+                                 &enable_pooling_cache);
+  if (enable_pooling_cache) m_pooling_cache_allowed = true;
 }
 
 port::Status MIOpenSupport::Init() {
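Note: the backward-pooling workspace cache is therefore opt-in via the `TF_ROCM_BW_POOL_CACHE` environment variable. A rough stand-in for what the `ReadBoolFromEnvVar` call does (an assumption for illustration; the real helper also accepts spellings like "true"/"false"):

```cpp
#include <cstdlib>
#include <cstring>

// Sketch: treat any value other than unset or "0" as enabling the flag,
// e.g. running with TF_ROCM_BW_POOL_CACHE=1.
bool EnvFlagEnabledSketch(const char* name, bool default_value) {
  const char* value = std::getenv(name);
  if (value == nullptr) return default_value;
  return std::strcmp(value, "0") != 0;
}
```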
@@ -844,17 +850,19 @@ class ScopedPoolingDescriptor {
     std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
                    &CheckedNarrowing<int64, int>);
 
-    if (nd != 2) {
-      LOG(FATAL) << "miopen requires pooling dimensions be 2"
-                 << ToString(status);
-    }
-
-    status = wrap::miopenSet2dPoolingDescriptor(
+    status = wrap::miopenSetNdPoolingDescriptor(
         handle_,
         (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
              ? miopenPoolingMax
              : miopenPoolingAverage),
-        shape[0], shape[1], padding[0], padding[1], strides[0], strides[1]);
+        nd, shape.data(), padding.data(), strides.data());
+
+    // Note: The index type has to be uint32 type for now because the MIOpen
+    // API assumes all input indexes to be the same type. Since a tensor
+    // descriptor can only use int32 type, the index type here needs to be
+    // aligned with the tensor index type of the (input) tensor descriptor.
+    status = wrap::miopenSetPoolingIndexType(handle_, miopenIndexUint32);
+
     if (status != miopenStatusSuccess) {
       LOG(FATAL) << "could not set miopen pooling descriptor: "
                  << ToString(status);
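Note: for illustration, this is how the Nd setter adopted above would configure a 3D max pool (2x2x2 window, zero padding, stride 2). A hedged sketch assuming the MIOpen headers are available; `ConfigurePool3d` is a hypothetical helper, not part of this change, and the arrays follow the depth-height-width order implied by the kBatchDepthYX layout:

```cpp
#include <miopen/miopen.h>

// Sketch: descriptor creation/destruction elided.
miopenStatus_t ConfigurePool3d(miopenPoolingDescriptor_t desc) {
  int window[3] = {2, 2, 2};
  int padding[3] = {0, 0, 0};
  int strides[3] = {2, 2, 2};
  miopenStatus_t s = miopenSetNdPoolingDescriptor(
      desc, miopenPoolingMax, /*nbDims=*/3, window, padding, strides);
  if (s != miopenStatusSuccess) return s;
  // Match the uint32 indexing assumed by the tensor descriptors (see the
  // comment in the hunk above).
  return miopenSetPoolingIndexType(desc, miopenIndexUint32);
}
```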
@@ -4009,10 +4017,94 @@ bool MIOpenSupport::DoPoolForward(
     const DeviceMemory<double>& input_data,
     const dnn::BatchDescriptor& output_dimensions,
     DeviceMemory<double>* output_data, ScratchAllocator* workspace_allocator) {
-  LOG(ERROR) << "miopen does not support pooling for dobule type yet";
+  LOG(ERROR) << "miopen does not support pooling for double type yet";
   return false;
 }
 
+bool PoolingWorkspaceDescriptor::IsSame(
+    const dnn::BatchDescriptor& input_dimensions,
+    const dnn::BatchDescriptor& output_dimensions,
+    const dnn::PoolingDescriptor& pooling_dimensions, int _type) {
+  return dtype == _type &&
+         input_dims ==
+             input_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX) &&
+         output_dims ==
+             output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX) &&
+         op.mode() == pooling_dimensions.mode() &&
+         op.window() == pooling_dimensions.window() &&
+         op.padding() == pooling_dimensions.padding() &&
+         op.strides() == pooling_dimensions.strides();
+}
+
+bool PoolingWorkspaceCache::find(
+    const void* p, const dnn::BatchDescriptor& input_dimensions,
+    const dnn::BatchDescriptor& output_dimensions,
+    const dnn::PoolingDescriptor& pooling_dimensions, int _type,
+    PoolingWorkspaceDescriptor*& pdesc) {
+  pdesc = 0;
+  auto it = cache.find(p);
+  if (it == cache.end()) {
+    return false;
+  }
+  if (!it->second.IsSame(input_dimensions, output_dimensions,
+                         pooling_dimensions, _type)) {
+    return false;
+  }
+  pdesc = &it->second;
+  return true;
+}
+
+void PoolingWorkspaceCache::insert(
+    const void* p, const dnn::BatchDescriptor& input_dimensions,
+    const dnn::BatchDescriptor& output_dimensions,
+    const dnn::PoolingDescriptor& pooling_dimensions, int _type,
+    std::unique_ptr<TemporaryDeviceMemory<uint8>>& workspace, size_t wsp_size,
+    hipStream_t hip_stream) {
+  PoolingWorkspaceDescriptor* desc = 0;
+  auto it = cache.find(p);
+  if (it != cache.end()) {
+    // replacing an entry with the same pointer but different attributes
+    // (if everything matches, the caller is expected to reuse the entry)
+    desc = &it->second;
+    hipStreamSynchronize(hip_stream);
+    memory_used -= desc->workspace_size;
+  } else {
+    cache[p] = PoolingWorkspaceDescriptor();
+    desc = &cache[p];
+  }
+  desc->input_dims = input_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
+  desc->output_dims =
+      output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
+  desc->op = pooling_dimensions;
+  desc->dtype = _type;
+  desc->timestamp = timestamp;
+  timestamp++;
+  desc->workspace = std::move(workspace);
+  desc->workspace_size = wsp_size;
+  memory_used += wsp_size;
+  trim(hip_stream);
+}
+
+void PoolingWorkspaceCache::trim(hipStream_t hip_stream) {
+  if (memory_used < memory_budget && cache.size() < trim_size) return;
+  bool must_sync = true;
+  while (true) {
+    int new_size = cache.size() - (cache.size() >> 2);
+    std::vector<const void*> old_entries;
+    for (auto& x : cache)
+      if (x.second.timestamp + new_size < timestamp)
+        old_entries.push_back(x.first);
+    if (old_entries.empty()) break;
+    if (must_sync) hipStreamSynchronize(hip_stream);
+    must_sync = true;
+    for (auto x : old_entries) {
+      memory_used -= cache[x].workspace_size;
+      cache.erase(x);
+    }
+    if (memory_used < memory_budget || cache.size() < 10) break;
+  }
+}
+
 bool MIOpenSupport::DoPoolForward(
     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
     const dnn::BatchDescriptor& input_dimensions,
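Note: the eviction rule in `PoolingWorkspaceCache::trim` above is timestamp-based: entries whose insertion timestamp lags the current counter by more than about three quarters of the cache size are dropped, shrinking the cache by roughly a quarter per pass. A self-contained toy of that rule (names hypothetical, payload reduced to a bare timestamp):

```cpp
#include <cstdint>
#include <map>
#include <vector>

// Toy version of the trim rule: keep an entry only if its insertion
// timestamp is within ~new_size ticks of the current counter.
void TrimToy(std::map<const void*, uint64_t>& cache, uint64_t timestamp) {
  uint64_t new_size = cache.size() - (cache.size() >> 2);
  std::vector<const void*> old_entries;
  for (auto& entry : cache)
    if (entry.second + new_size < timestamp) old_entries.push_back(entry.first);
  for (const void* key : old_entries) cache.erase(key);
}
```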
@@ -4020,7 +4112,6 @@ bool MIOpenSupport::DoPoolForward(
     const dnn::BatchDescriptor& output_dimensions,
     DeviceMemory<float>* output_data, ScratchAllocator* workspace_allocator) {
   auto miopen = miopen_->GetHandle(parent_, stream);
-
   // Alpha is the scaling factor for input.
   float alpha = 1.0;
   // Beta is the scaling factor for output.
@@ -4030,10 +4121,48 @@ bool MIOpenSupport::DoPoolForward(
   ScopedTensorDescriptor dest_desc{output_dimensions, miopenFloat};
   ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
 
+  bool do_backward = false;
+  uint8* workspace = 0;
+  size_t workspace_size = 0;
+  std::unique_ptr<TemporaryDeviceMemory<uint8>> wsp_mem;
+  if (m_pooling_cache_enabled) {
+    do_backward = true;
+    auto status = wrap::miopenPoolingGetWorkSpaceSizeV2(
+        pooling_desc.handle(), dest_desc.handle(), &workspace_size);
+    if (status != miopenStatusSuccess) {
+      LOG(ERROR)
+          << "failed to obtain workspace size for backward pooling on stream: "
+          << ToString(status);
+      return false;
+    }
+    if (workspace_size != 0) {
+      PoolingWorkspaceDescriptor* pdesc = 0;
+      bool cache_hit =
+          m_pooling_cache_allowed &&
+          m_pooling_cache.find(input_data.opaque(), input_dimensions,
+                               output_dimensions, pooling_dimensions,
+                               miopenFloat, pdesc);
+      if (cache_hit) {
+        // reusing the same buffer
+        workspace = reinterpret_cast<uint8*>(
+            pdesc->workspace->mutable_device_memory()->opaque());
+      } else {
+        wsp_mem = stream->AllocateTemporaryArray<uint8>(workspace_size)
+                      .ConsumeValueOrDie();
+        workspace = reinterpret_cast<uint8*>(
+            wsp_mem->mutable_device_memory()->opaque());
+        m_pooling_cache.insert(input_data.opaque(), input_dimensions,
+                               output_dimensions, pooling_dimensions,
+                               miopenFloat, wsp_mem, workspace_size,
+                               AsGpuStreamValue(stream));
+      }
+    }
+  }
+
   auto status = wrap::miopenPoolingForward(
       miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
       input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque(),
-      false, nullptr, 0);
+      do_backward, workspace, workspace_size);
   if (status != miopenStatusSuccess) {
     LOG(ERROR) << "failed to enqueue forward pooling on stream: "
                << ToString(status);
@@ -4072,6 +4201,118 @@ bool MIOpenSupport::DoPoolForward(
   return true;
 }
 
+template <class T>
+bool MIOpenSupport::DoPoolBackwardImpl(
+    Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
+    const dnn::BatchDescriptor& input_dimensions,
+    const DeviceMemory<T>& input_data,
+    const dnn::BatchDescriptor& output_dimensions,
+    const DeviceMemory<T>& output_data, const DeviceMemory<T>& input_diff_data,
+    DeviceMemory<T>* output_diff_data, ScratchAllocator* workspace_allocator) {
+  auto miopen = miopen_->GetHandle(parent_, stream);
+  if (m_pooling_cache_allowed) m_pooling_cache_enabled = true;
+  // Alpha is the scaling factor for input.
+  float alpha = 1.0;
+  // Beta is the scaling factor for output.
+  float beta = 0.0;
+
+  auto type =
+      std::is_same<T, float>::value
+          ? miopenFloat
+          : (std::is_same<T, Eigen::half>::value ? miopenHalf
+                                                 : (miopenDataType_t)-1);
+
+  ScopedTensorDescriptor src_desc{input_dimensions, type};
+  ScopedTensorDescriptor dest_desc{output_dimensions, type};
+  ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
+
+  uint8* workspace_ptr = 0;
+  DeviceMemory<uint8> workspace;
+  PoolingWorkspaceDescriptor* pdesc = 0;
+
+  size_t workspace_size_in_bytes = 0;
+  auto status = wrap::miopenPoolingGetWorkSpaceSizeV2(
+      pooling_desc.handle(), dest_desc.handle(), &workspace_size_in_bytes);
+  if (status != miopenStatusSuccess) {
+    LOG(ERROR)
+        << "failed to obtain workspace size for backward pooling on stream: "
+        << ToString(status);
+    return false;
+  }
+
+  // Allocate the workspace.
+  if (workspace_size_in_bytes > 0) {
+    bool cache_hit = m_pooling_cache_allowed &&
+                     m_pooling_cache.find(input_data.opaque(), input_dimensions,
+                                          output_dimensions, pooling_dimensions,
+                                          type, pdesc);
+    if (cache_hit) {
+      assert(pdesc != 0);
+      workspace_ptr = reinterpret_cast<uint8*>(
+          pdesc->workspace->mutable_device_memory()->opaque());
+      VLOG(1) << "Pooling cache hit";
+    } else {
+      VLOG(1) << "Pooling cache miss";
+      assert(workspace_allocator);
+      auto allocated =
+          workspace_allocator->AllocateBytes(workspace_size_in_bytes);
+      if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
+        LOG(ERROR) << "Failed to allocate backward pooling workspace";
+        return false;
+      }
+      DeviceMemory<uint8> dest2;  // duplicated dest from forward:
+      int64 dest2_size = 0;
+
+      // miopen requires the strides and dims to be ordered as BDYX.
+      std::vector<int64> dims64 =
+          output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
+      // miopen does not use strides and must have 4D tensor.
+      // std::vector<int> dims(pooling_dimensions.ndims() + 2);
+
+      dest2_size = sizeof(T);
+      for (auto& x : dims64) dest2_size *= x;
+
+      if (dest2_size > 0) {
+        assert(workspace_allocator);
+        auto allocated = workspace_allocator->AllocateBytes(dest2_size);
+        if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
+          LOG(ERROR) << "Failed to allocate backward pooling workspace";
+          return false;
+        }
+      } else {
+        LOG(ERROR) << "Failed to calculate tensor size to chain forward and "
+                      "backward pooling";
+      }
+
+      status = wrap::miopenPoolingForward(
+          miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
+          input_data.opaque(), &beta, dest_desc.handle(), dest2.opaque(), true,
+          workspace.opaque(), workspace_size_in_bytes);
+
+      if (status != miopenStatusSuccess) {
+        LOG(ERROR)
+            << "failed to enqueue forward pooling (before backward) on stream: "
+            << ToString(status);
+        return false;
+      }
+      workspace_ptr = reinterpret_cast<uint8*>(workspace.opaque());
+    }
+  }
+  status = wrap::miopenPoolingBackward(
+      miopen.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
+      output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
+      src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
+      output_diff_data->opaque(), workspace_ptr);
+
+  if (status != miopenStatusSuccess) {
+    LOG(ERROR) << "failed to enqueue backward pooling on stream: "
+               << ToString(status);
+    return false;
+  }
+
+  return true;
+}
+
 bool MIOpenSupport::DoPoolBackward(
     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
     const dnn::BatchDescriptor& input_dimensions,
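Note: the reason `DoPoolBackwardImpl` replays the forward op on a cache miss is that MIOpen's backward pooling consumes the index workspace written by a forward call with `do_backward=true`. A hedged sketch of that call order, assuming the MIOpen headers and live handles, descriptors, and device buffers (everything outside the two calls is elided):

```cpp
#include <miopen/miopen.h>

// Sketch: all descriptors and device buffers are assumed set up elsewhere.
miopenStatus_t PoolBackwardSketch(
    miopenHandle_t handle, miopenPoolingDescriptor_t pool_desc,
    miopenTensorDescriptor_t x_desc, const void* x,  // forward input
    miopenTensorDescriptor_t y_desc, void* y,        // forward output scratch
    const void* dy, void* dx,                        // gradient in / out
    void* workspace, size_t workspace_size) {
  float alpha = 1.0f, beta = 0.0f;
  // 1. Replayed forward pass; do_backward=true records pooling indices.
  miopenStatus_t s =
      miopenPoolingForward(handle, pool_desc, &alpha, x_desc, x, &beta, y_desc,
                           y, /*do_backward=*/true, workspace, workspace_size);
  if (s != miopenStatusSuccess) return s;
  // 2. Backward pass routes dy through the recorded indices to produce dx.
  return miopenPoolingBackward(handle, pool_desc, &alpha, y_desc, y, y_desc,
                               dy, x_desc, x, &beta, x_desc, dx, workspace);
}
```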
@@ -4094,91 +4335,10 @@ bool MIOpenSupport::DoPoolBackward(
     const DeviceMemory<float>& input_diff_data,
     DeviceMemory<float>* output_diff_data,
     ScratchAllocator* workspace_allocator) {
-  auto miopen = miopen_->GetHandle(parent_, stream);
-
-  // Alpha is the scaling factor for input.
-  float alpha = 1.0;
-  // Beta is the scaling factor for output.
-  float beta = 0.0;
-
-  ScopedTensorDescriptor src_desc{input_dimensions, miopenFloat};
-  ScopedTensorDescriptor dest_desc{output_dimensions, miopenFloat};
-  ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
-
-  DeviceMemory<uint8> workspace;
-  size_t workspace_size_in_bytes = 0;
-  auto status = wrap::miopenPoolingGetWorkSpaceSize(dest_desc.handle(),
-                                                    &workspace_size_in_bytes);
-
-  if (status != miopenStatusSuccess) {
-    LOG(ERROR)
-        << "failed to obtain workspace size for backward pooling on stream: "
-        << ToString(status);
-    return false;
-  }
-
-  // Allocate the workspace.
-  if (workspace_size_in_bytes > 0) {
-    assert(workspace_allocator);
-    auto allocated =
-        workspace_allocator->AllocateBytes(workspace_size_in_bytes);
-    if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
-      LOG(ERROR) << "Failed to allocate backward pooling workspace";
-      return false;
-    }
-  }
-
-  DeviceMemory<uint8> dest2;  // duplicated dest from forward:
-  int dest2_size = 0;
-
-  // miopen requires the strides and dims to be ordered as BDYX.
-  std::vector<int64> dims64 =
-      output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
-
-  // miopen does not use strides and must have 4D tensor.
-  std::vector<int> dims(4);
-
-  std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
-                 &CheckedNarrowing<int64, int>);
-
-  dest2_size = dims[0] * dims[1] * dims[2] * dims[3] * sizeof(float);
-
-  if (dest2_size > 0) {
-    assert(workspace_allocator);
-    auto allocated = workspace_allocator->AllocateBytes(dest2_size);
-    if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
-      LOG(ERROR) << "Failed to allocate backward pooling workspace";
-      return false;
-    }
-  } else {
-    LOG(ERROR) << "Failed to calculate tensor size to chain forward and "
-                  "backward pooling";
-  }
-
-  status = wrap::miopenPoolingForward(
-      miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
-      input_data.opaque(), &beta, dest_desc.handle(), dest2.opaque(), true,
-      workspace.opaque(), workspace_size_in_bytes);
-
-  if (status != miopenStatusSuccess) {
-    LOG(ERROR)
-        << "failed to enqueue forward pooling (before backward) on stream: "
-        << ToString(status);
-    return false;
-  }
-
-  status = wrap::miopenPoolingBackward(
-      miopen.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
-      dest2.opaque(), dest_desc.handle(), input_diff_data.opaque(),
-      src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
-      output_diff_data->opaque(), workspace.opaque());
-
-  if (status != miopenStatusSuccess) {
-    LOG(ERROR) << "failed to enqueue backward pooling on stream: "
-               << ToString(status);
-    return false;
-  }
-  return true;
+  return DoPoolBackwardImpl(stream, pooling_dimensions, input_dimensions,
+                            input_data, output_dimensions, output_data,
+                            input_diff_data, output_diff_data,
+                            workspace_allocator);
 }
 
 bool MIOpenSupport::DoPoolBackward(
@@ -4190,91 +4350,10 @@ bool MIOpenSupport::DoPoolBackward(
     const DeviceMemory<Eigen::half>& input_diff_data,
     DeviceMemory<Eigen::half>* output_diff_data,
     ScratchAllocator* workspace_allocator) {
-  auto miopen = miopen_->GetHandle(parent_, stream);
-
-  // Alpha is the scaling factor for input.
-  float alpha = 1.0;
-  // Beta is the scaling factor for output.
-  float beta = 0.0;
-
-  ScopedTensorDescriptor src_desc{input_dimensions, miopenHalf};
-  ScopedTensorDescriptor dest_desc{output_dimensions, miopenHalf};
-  ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
-
-  DeviceMemory<uint8> workspace;
-  size_t workspace_size_in_bytes = 0;
-  auto status = wrap::miopenPoolingGetWorkSpaceSize(dest_desc.handle(),
-                                                    &workspace_size_in_bytes);
-
-  if (status != miopenStatusSuccess) {
-    LOG(ERROR)
-        << "failed to obtain workspace size for backward pooling on stream: "
-        << ToString(status);
-    return false;
-  }
-
-  // Allocate the workspace.
-  if (workspace_size_in_bytes > 0) {
-    assert(workspace_allocator);
-    auto allocated =
-        workspace_allocator->AllocateBytes(workspace_size_in_bytes);
-    if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
-      LOG(ERROR) << "Failed to allocate backward pooling workspace";
-      return false;
-    }
-  }
-
-  DeviceMemory<uint8> dest2;  // duplicated dest from forward:
-  int dest2_size = 0;
-
-  // miopen requires the strides and dims to be ordered as BDYX.
-  std::vector<int64> dims64 =
-      output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
-
-  // miopen does not use strides and must have 4D tensor.
-  std::vector<int> dims(4);
-
-  std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
-                 &CheckedNarrowing<int64, int>);
-
-  dest2_size = dims[0] * dims[1] * dims[2] * dims[3] * sizeof(float);
-
-  if (dest2_size > 0) {
-    assert(workspace_allocator);
-    auto allocated = workspace_allocator->AllocateBytes(dest2_size);
-    if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
-      LOG(ERROR) << "Failed to allocate backward pooling workspace";
-      return false;
-    }
-  } else {
-    LOG(ERROR) << "Failed to calculate tensor size to chain forward and "
-                  "backward pooling";
-  }
-
-  status = wrap::miopenPoolingForward(
-      miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
-      input_data.opaque(), &beta, dest_desc.handle(), dest2.opaque(), true,
-      workspace.opaque(), workspace_size_in_bytes);
-
-  if (status != miopenStatusSuccess) {
-    LOG(ERROR)
-        << "failed to enqueue forward pooling (before backward) on stream: "
-        << ToString(status);
-    return false;
-  }
-
-  status = wrap::miopenPoolingBackward(
-      miopen.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
-      dest2.opaque(), dest_desc.handle(), input_diff_data.opaque(),
-      src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
-      output_diff_data->opaque(), workspace.opaque());
-
-  if (status != miopenStatusSuccess) {
-    LOG(ERROR) << "failed to enqueue backward pooling on stream: "
-               << ToString(status);
-    return false;
-  }
-  return true;
+  return DoPoolBackwardImpl(stream, pooling_dimensions, input_dimensions,
+                            input_data, output_dimensions, output_data,
+                            input_diff_data, output_diff_data,
+                            workspace_allocator);
 }
 
 bool MIOpenSupport::DoNormalizeWithDimensions(
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DNN_H_
|
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DNN_H_
|
||||||
|
|
||||||
#include "absl/synchronization/mutex.h"
|
#include "absl/synchronization/mutex.h"
|
||||||
|
#include "rocm/include/miopen/miopen.h"
|
||||||
#include "tensorflow/core/platform/thread_annotations.h"
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
#include "tensorflow/stream_executor/dnn.h"
|
#include "tensorflow/stream_executor/dnn.h"
|
||||||
#include "tensorflow/stream_executor/lib/status.h"
|
#include "tensorflow/stream_executor/lib/status.h"
|
||||||
@@ -38,6 +39,39 @@ class MIOpenCTCLossDescriptor;
 // Opaque and unique identifier for the MIOpen plugin.
 extern const PluginId kMIOpenPlugin;
 
+struct PoolingWorkspaceDescriptor {
+  std::vector<int64> input_dims;
+  std::vector<int64> output_dims;
+  dnn::PoolingDescriptor op;
+  int dtype;
+  uint64_t timestamp;
+  std::unique_ptr<TemporaryDeviceMemory<uint8>> workspace;
+  size_t workspace_size;
+  bool IsSame(const dnn::BatchDescriptor& input_dimensions,
+              const dnn::BatchDescriptor& output_dimensions,
+              const dnn::PoolingDescriptor& pooling_dimensions, int _type);
+};
+
+struct PoolingWorkspaceCache {
+  std::map<const void*, PoolingWorkspaceDescriptor> cache;
+  const int trim_size = 1000;
+  const uint64_t memory_budget = 2e7;
+  uint64_t timestamp = 0;
+  uint64_t memory_used = 0;
+  bool find(const void* p, const dnn::BatchDescriptor& input_dimensions,
+            const dnn::BatchDescriptor& output_dimensions,
+            const dnn::PoolingDescriptor& pooling_dimensions, int _type,
+            PoolingWorkspaceDescriptor*& pdesc);
+  void insert(const void* p, const dnn::BatchDescriptor& input_dimensions,
+              const dnn::BatchDescriptor& output_dimensions,
+              const dnn::PoolingDescriptor& pooling_dimensions, int _type,
+              std::unique_ptr<TemporaryDeviceMemory<uint8>>& workspace,
+              size_t wsp_size, hipStream_t hip_stream);
+
+ private:
+  void trim(hipStream_t hip_stream);
+};
+
 // miopen-library based DNN support. For details on overridden interface
 // functions, see dnn.h.
 class MIOpenSupport : public dnn::DnnSupport {
@@ -664,6 +698,10 @@ class MIOpenSupport : public dnn::DnnSupport {
   // Provide access to the MIOpen handle.
   std::unique_ptr<class MIOpenAccess> miopen_;
 
+  PoolingWorkspaceCache m_pooling_cache;
+  bool m_pooling_cache_allowed = false;
+  bool m_pooling_cache_enabled = false;
+
   template <class T, class U>
   bool DoBatchNormalizationForwardImpl(
       Stream* stream, dnn::DataType input_data_type,
@@ -847,6 +885,36 @@ class MIOpenSupport : public dnn::DnnSupport {
       ScratchAllocator* scratch_allocator,
       std::vector<dnn::ProfileResult>* out_algorithms);
 
+  port::Status DoCtcLossImpl(
+      Stream* stream, const MIOpenRnnStateTensorDescriptor& probs_desc,
+      const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
+      absl::Span<const int> labels_lengths_data,
+      absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
+      const MIOpenRnnStateTensorDescriptor& grads_desc,
+      DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc,
+      DeviceMemory<uint8> scratch_memory);
+
+  port::Status DoPrepareForCtcLoss(
+      Stream* stream, dnn::DataType element_type,
+      const dnn::RnnStateTensorDescriptor& probs_desc,
+      const dnn::RnnStateTensorDescriptor& grads_desc,
+      absl::Span<const int> labels_data,
+      absl::Span<const int> labels_lengths_data,
+      absl::Span<const int> input_lengths_data,
+      ScratchAllocator* scratch_allocator,
+      DeviceMemory<uint8>* scratch_memory) override;
+
+  template <class T>
+  bool DoPoolBackwardImpl(Stream* stream,
+                          const dnn::PoolingDescriptor& pooling_dimensions,
+                          const dnn::BatchDescriptor& input_dimensions,
+                          const DeviceMemory<T>& input_data,
+                          const dnn::BatchDescriptor& output_dimensions,
+                          const DeviceMemory<T>& output_data,
+                          const DeviceMemory<T>& input_diff_data,
+                          DeviceMemory<T>* output_diff_data,
+                          ScratchAllocator* workspace_allocator = nullptr);
+
   SE_DISALLOW_COPY_AND_ASSIGN(MIOpenSupport);
 };