Adding 3d Pooling using latest MIOpen API

This commit is contained in:
Eugene Kuznetsov 2020-06-09 15:11:45 +00:00
parent 291125835e
commit 8681b1bf54
8 changed files with 360 additions and 200 deletions

View File

@ -259,9 +259,6 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) {
RunTest(x, x_init_value, y, y_shape); RunTest(x, x_init_value, y, y_shape);
} }
// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, MaxPool3DGradHelper) { TEST_F(NNGradTest, MaxPool3DGradHelper) {
TensorShape x_shape({1, 3, 3, 3, 1}); TensorShape x_shape({1, 3, 3, 3, 1});
TensorShape y_shape({1, 1, 1, 1, 1}); TensorShape y_shape({1, 1, 1, 1, 1});
@ -274,7 +271,6 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) {
SetRandomValuesForMaxPooling<float>(&x_init_value); SetRandomValuesForMaxPooling<float>(&x_init_value);
RunTest(x, x_init_value, y, y_shape); RunTest(x, x_init_value, y, y_shape);
} }
#endif
TEST_F(NNGradTest, AvgPoolGradHelper) { TEST_F(NNGradTest, AvgPoolGradHelper) {
TensorShape x_shape({1, 2, 2, 1}); TensorShape x_shape({1, 2, 2, 1});
@ -287,9 +283,6 @@ TEST_F(NNGradTest, AvgPoolGradHelper) {
RunTest(x, x_shape, y, y_shape); RunTest(x, x_shape, y, y_shape);
} }
// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, AvgPool3DGradHelper) { TEST_F(NNGradTest, AvgPool3DGradHelper) {
TensorShape x_shape({1, 3, 3, 3, 1}); TensorShape x_shape({1, 3, 3, 3, 1});
TensorShape y_shape({1, 1, 1, 1, 1}); TensorShape y_shape({1, 1, 1, 1, 1});
@ -300,7 +293,6 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) {
auto y = AvgPool3D(scope_, x, ksize, strides, "SAME"); auto y = AvgPool3D(scope_, x, ksize, strides, "SAME");
RunTest(x, x_shape, y, y_shape); RunTest(x, x_shape, y, y_shape);
} }
#endif
TEST_F(NNGradTest, LRN) { TEST_F(NNGradTest, LRN) {
TensorShape x_shape({1, 1, 2, 1}); TensorShape x_shape({1, 1, 2, 1});

View File

@ -98,10 +98,25 @@ void DnnPooling3dOp<T>::Compute(OpKernelContext* context,
auto* stream = context->op_device_context()->stream(); auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
#if TENSORFLOW_USE_ROCM
static int64 PoolingScratchSize = GetDnnWorkspaceLimit(
// default value is in bytes despite the name of the environment variable
"TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB
);
DnnScratchAllocator scratch_allocator(PoolingScratchSize, context);
bool status =
stream
->ThenPoolForward(pooling_desc, input_desc, input_data, output_desc,
&output_data, &scratch_allocator)
.ok();
#else
bool status = stream bool status = stream
->ThenPoolForward(pooling_desc, input_desc, input_data, ->ThenPoolForward(pooling_desc, input_desc, input_data,
output_desc, &output_data) output_desc, &output_data)
.ok(); .ok();
#endif
OP_REQUIRES(context, status, OP_REQUIRES(context, status,
errors::Internal("dnn PoolForward launch failed")); errors::Internal("dnn PoolForward launch failed"));
@ -225,12 +240,28 @@ void DnnPooling3dGradOp<T>::Compute(
auto* stream = context->op_device_context()->stream(); auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
#if TENSORFLOW_USE_ROCM
static int64 PoolingScratchSize = GetDnnWorkspaceLimit(
// default value is in bytes despite the name of the environment variable
"TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB
);
DnnScratchAllocator scratch_allocator(PoolingScratchSize, context);
bool status = stream
->ThenPoolBackward(pooling_desc, orig_input_desc,
orig_input_data, orig_output_desc,
orig_output_data, output_backprop_data,
&input_backprop_data, &scratch_allocator)
.ok();
#else
bool status = bool status =
stream stream
->ThenPoolBackward(pooling_desc, orig_input_desc, orig_input_data, ->ThenPoolBackward(pooling_desc, orig_input_desc, orig_input_data,
orig_output_desc, orig_output_data, orig_output_desc, orig_output_data,
output_backprop_data, &input_backprop_data) output_backprop_data, &input_backprop_data)
.ok(); .ok();
#endif
OP_REQUIRES(context, status, OP_REQUIRES(context, status,
errors::Internal("dnn PoolBackward launch failed")); errors::Internal("dnn PoolBackward launch failed"));

View File

@ -1455,9 +1455,6 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testMaxPooling3DGradient(self): def testMaxPooling3DGradient(self):
if test.is_built_with_rocm():
self.skipTest('Pooling with 3D tensors is not supported in ROCm')
def forward(a): def forward(a):
r = max_pooling3d(a, pool_size=pool_size, strides=strides, padding='SAME') r = max_pooling3d(a, pool_size=pool_size, strides=strides, padding='SAME')
return r return r

View File

@ -2995,7 +2995,6 @@ cuda_py_test(
name = "pooling_ops_3d_test", name = "pooling_ops_3d_test",
size = "medium", size = "medium",
srcs = ["pooling_ops_3d_test.py"], srcs = ["pooling_ops_3d_test.py"],
tags = ["no_rocm"],
deps = [ deps = [
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_for_generated_wrappers",

View File

@ -219,8 +219,6 @@ class PoolingTest(test.TestCase):
strides=strides) strides=strides)
def testPool3D(self): def testPool3D(self):
if test.is_built_with_rocm():
self.skipTest("Pooling with 3D tensors is not supported in ROCm")
with self.session(use_gpu=test.is_gpu_available()): with self.session(use_gpu=test.is_gpu_available()):
for padding in ["SAME", "VALID"]: for padding in ["SAME", "VALID"]:
for pooling_type in ["MAX", "AVG"]: for pooling_type in ["MAX", "AVG"]:
@ -363,8 +361,6 @@ class PoolingTest(test.TestCase):
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testGradient3D(self): def testGradient3D(self):
if test.is_built_with_rocm():
self.skipTest("Pooling with 3D tensors is not supported in ROCm")
with self.session(use_gpu=test.is_gpu_available()): with self.session(use_gpu=test.is_gpu_available()):
for padding in ["SAME", "VALID"]: for padding in ["SAME", "VALID"]:
for pooling_type in ["AVG", "MAX"]: for pooling_type in ["AVG", "MAX"]:

View File

@ -488,8 +488,6 @@ class NNTest(PForTestCase):
self._test_loop_fn(loop_fn, 3) self._test_loop_fn(loop_fn, 3)
def test_max_pool3d(self): def test_max_pool3d(self):
if test.is_built_with_rocm():
self.skipTest("Pooling with 3D tensors is not supported in ROCm")
with backprop.GradientTape(persistent=True) as g: with backprop.GradientTape(persistent=True) as g:
x = random_ops.random_uniform([3, 3, 2, 12, 12, 3]) x = random_ops.random_uniform([3, 3, 2, 12, 12, 3])
g.watch(x) g.watch(x)

View File

@ -263,7 +263,8 @@ namespace wrap {
__macro(miopenFindConvolutionForwardAlgorithm) \ __macro(miopenFindConvolutionForwardAlgorithm) \
__macro(miopenCreateTensorDescriptor) \ __macro(miopenCreateTensorDescriptor) \
__macro(miopenDestroyTensorDescriptor) \ __macro(miopenDestroyTensorDescriptor) \
__macro(miopenSet2dPoolingDescriptor) \ __macro(miopenSetNdPoolingDescriptor) \
__macro(miopenSetPoolingIndexType) \
__macro(miopenSetLRNDescriptor) \ __macro(miopenSetLRNDescriptor) \
__macro(miopenLRNGetWorkSpaceSize) \ __macro(miopenLRNGetWorkSpaceSize) \
__macro(miopenCreateConvolutionDescriptor) \ __macro(miopenCreateConvolutionDescriptor) \
@ -290,7 +291,7 @@ namespace wrap {
__macro(miopenSetTensorDescriptor) \ __macro(miopenSetTensorDescriptor) \
__macro(miopenGetTensorDescriptorSize) \ __macro(miopenGetTensorDescriptorSize) \
__macro(miopenPoolingForward) \ __macro(miopenPoolingForward) \
__macro(miopenPoolingGetWorkSpaceSize) \ __macro(miopenPoolingGetWorkSpaceSizeV2 \
__macro(miopenPoolingBackward) \ __macro(miopenPoolingBackward) \
__macro(miopenLRNForward) \ __macro(miopenLRNForward) \
__macro(miopenLRNBackward) \ __macro(miopenLRNBackward) \
@ -605,6 +606,11 @@ MIOpenSupport::MIOpenSupport(GpuExecutor* parent) : parent_(parent) {
// swich to Find Mode if env var TF_ROCM_USE_IMMEDIATE_MODE is set // swich to Find Mode if env var TF_ROCM_USE_IMMEDIATE_MODE is set
tensorflow::ReadBoolFromEnvVar("TF_ROCM_USE_IMMEDIATE_MODE", false, tensorflow::ReadBoolFromEnvVar("TF_ROCM_USE_IMMEDIATE_MODE", false,
&use_immediate_mode_); &use_immediate_mode_);
bool enable_pooling_cache = false;
tensorflow::ReadBoolFromEnvVar("TF_ROCM_BW_POOL_CACHE", false,
&enable_pooling_cache);
if (enable_pooling_cache) m_pooling_cache_allowed = true;
} }
port::Status MIOpenSupport::Init() { port::Status MIOpenSupport::Init() {
@ -844,17 +850,19 @@ class ScopedPoolingDescriptor {
std::transform(shape64.cbegin(), shape64.cend(), shape.begin(), std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
&CheckedNarrowing<int64, int>); &CheckedNarrowing<int64, int>);
if (nd != 2) { status = wrap::miopenSetNdPoolingDescriptor(
LOG(FATAL) << "miopen requires pooling dimensions be 2"
<< ToString(status);
}
status = wrap::miopenSet2dPoolingDescriptor(
handle_, handle_,
(pooling_descriptor.mode() == dnn::PoolingMode::kMaximum (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
? miopenPoolingMax ? miopenPoolingMax
: miopenPoolingAverage), : miopenPoolingAverage),
shape[0], shape[1], padding[0], padding[1], strides[0], strides[1]); nd, shape.data(), padding.data(), strides.data());
// Note: The index type has to be uint32 type for now because MIOpen
// API assumes all input indexes to be the same type. Since a tensor
// descriptor can only use int32 type, the index type here need to be
// aligned with the tensor index type of the (input) tensor descritptor
status = wrap::miopenSetPoolingIndexType(handle_, miopenIndexUint32);
if (status != miopenStatusSuccess) { if (status != miopenStatusSuccess) {
LOG(FATAL) << "could not set miopen pooling descriptor: " LOG(FATAL) << "could not set miopen pooling descriptor: "
<< ToString(status); << ToString(status);
@ -4009,10 +4017,94 @@ bool MIOpenSupport::DoPoolForward(
const DeviceMemory<double>& input_data, const DeviceMemory<double>& input_data,
const dnn::BatchDescriptor& output_dimensions, const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<double>* output_data, ScratchAllocator* workspace_allocator) { DeviceMemory<double>* output_data, ScratchAllocator* workspace_allocator) {
LOG(ERROR) << "miopen does not support pooling for dobule type yet"; LOG(ERROR) << "miopen does not support pooling for double type yet";
return false; return false;
} }
bool PoolingWorkspaceDescriptor::IsSame(
const dnn::BatchDescriptor& input_dimensions,
const dnn::BatchDescriptor& output_dimensions,
const dnn::PoolingDescriptor& pooling_dimensions, int _type) {
return dtype == _type &&
input_dims ==
input_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX) &&
output_dims ==
output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX) &&
op.mode() == pooling_dimensions.mode() &&
op.window() == pooling_dimensions.window() &&
op.padding() == pooling_dimensions.padding() &&
op.strides() == pooling_dimensions.strides();
}
bool PoolingWorkspaceCache::find(
const void* p, const dnn::BatchDescriptor& input_dimensions,
const dnn::BatchDescriptor& output_dimensions,
const dnn::PoolingDescriptor& pooling_dimensions, int _type,
PoolingWorkspaceDescriptor*& pdesc) {
pdesc = 0;
auto it = cache.find(p);
if (it == cache.end()) {
return false;
}
if (!it->second.IsSame(input_dimensions, output_dimensions,
pooling_dimensions, _type)) {
return false;
}
pdesc = &it->second;
return true;
}
void PoolingWorkspaceCache::insert(
const void* p, const dnn::BatchDescriptor& input_dimensions,
const dnn::BatchDescriptor& output_dimensions,
const dnn::PoolingDescriptor& pooling_dimensions, int _type,
std::unique_ptr<TemporaryDeviceMemory<uint8>>& workspace, size_t wsp_size,
hipStream_t hip_stream) {
PoolingWorkspaceDescriptor* desc = 0;
auto it = cache.find(p);
if (it != cache.end()) {
// replacing an entry with the same pointer but different attributes
// (if everything matches, the caller is expected to reuse the entry)
desc = &it->second;
hipStreamSynchronize(hip_stream);
memory_used -= desc->workspace_size;
} else {
cache[p] = PoolingWorkspaceDescriptor();
desc = &cache[p];
}
desc->input_dims = input_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
desc->output_dims =
output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
desc->op = pooling_dimensions;
desc->dtype = _type;
desc->timestamp = timestamp;
timestamp++;
desc->workspace = std::move(workspace);
desc->workspace_size = wsp_size;
memory_used += wsp_size;
trim(hip_stream);
}
void PoolingWorkspaceCache::trim(hipStream_t hip_stream) {
if (memory_used < memory_budget && cache.size() < trim_size) return;
bool must_sync = true;
while (true) {
int new_size = cache.size() - (cache.size() >> 2);
std::vector<const void*> old_entries;
for (auto& x : cache)
if (x.second.timestamp + new_size < timestamp)
old_entries.push_back(x.first);
if (old_entries.empty()) break;
if (must_sync) hipStreamSynchronize(hip_stream);
must_sync = true;
for (auto x : old_entries) {
memory_used -= cache[x].workspace_size;
cache.erase(x);
}
if (memory_used < memory_budget || cache.size() < 10) break;
}
}
bool MIOpenSupport::DoPoolForward( bool MIOpenSupport::DoPoolForward(
Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
const dnn::BatchDescriptor& input_dimensions, const dnn::BatchDescriptor& input_dimensions,
@ -4020,7 +4112,6 @@ bool MIOpenSupport::DoPoolForward(
const dnn::BatchDescriptor& output_dimensions, const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<float>* output_data, ScratchAllocator* workspace_allocator) { DeviceMemory<float>* output_data, ScratchAllocator* workspace_allocator) {
auto miopen = miopen_->GetHandle(parent_, stream); auto miopen = miopen_->GetHandle(parent_, stream);
// Alpha is the scaling factor for input. // Alpha is the scaling factor for input.
float alpha = 1.0; float alpha = 1.0;
// Beta is the scaling factor for output. // Beta is the scaling factor for output.
@ -4030,10 +4121,48 @@ bool MIOpenSupport::DoPoolForward(
ScopedTensorDescriptor dest_desc{output_dimensions, miopenFloat}; ScopedTensorDescriptor dest_desc{output_dimensions, miopenFloat};
ScopedPoolingDescriptor pooling_desc{pooling_dimensions}; ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
bool do_backward = false;
uint8* workspace = 0;
size_t workspace_size = 0;
std::unique_ptr<TemporaryDeviceMemory<uint8>> wsp_mem;
if (m_pooling_cache_enabled) {
do_backward = true;
auto status = wrap::miopenPoolingGetWorkSpaceSizeV2(
pooling_desc.handle(), dest_desc.handle(), &workspace_size);
if (status != miopenStatusSuccess) {
LOG(ERROR)
<< "failed to obtain workspace size for backward pooling on stream: "
<< ToString(status);
return false;
}
if (workspace_size != 0) {
PoolingWorkspaceDescriptor* pdesc = 0;
bool cache_hit =
m_pooling_cache_allowed &&
m_pooling_cache.find(input_data.opaque(), input_dimensions,
output_dimensions, pooling_dimensions,
miopenFloat, pdesc);
if (cache_hit) {
// reusing the same buffer
workspace = reinterpret_cast<uint8*>(
pdesc->workspace->mutable_device_memory()->opaque());
} else {
wsp_mem = stream->AllocateTemporaryArray<uint8>(workspace_size)
.ConsumeValueOrDie();
workspace = reinterpret_cast<uint8*>(
wsp_mem->mutable_device_memory()->opaque());
m_pooling_cache.insert(input_data.opaque(), input_dimensions,
output_dimensions, pooling_dimensions,
miopenFloat, wsp_mem, workspace_size,
AsGpuStreamValue(stream));
}
}
}
auto status = wrap::miopenPoolingForward( auto status = wrap::miopenPoolingForward(
miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(), miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque(), input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque(),
false, nullptr, 0); do_backward, workspace, workspace_size);
if (status != miopenStatusSuccess) { if (status != miopenStatusSuccess) {
LOG(ERROR) << "failed to enqueue forward pooling on stream: " LOG(ERROR) << "failed to enqueue forward pooling on stream: "
<< ToString(status); << ToString(status);
@ -4072,6 +4201,118 @@ bool MIOpenSupport::DoPoolForward(
return true; return true;
} }
template <class T>
bool MIOpenSupport::DoPoolBackwardImpl(
Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
const dnn::BatchDescriptor& input_dimensions,
const DeviceMemory<T>& input_data,
const dnn::BatchDescriptor& output_dimensions,
const DeviceMemory<T>& output_data, const DeviceMemory<T>& input_diff_data,
DeviceMemory<T>* output_diff_data, ScratchAllocator* workspace_allocator) {
auto miopen = miopen_->GetHandle(parent_, stream);
if (m_pooling_cache_allowed) m_pooling_cache_enabled = true;
// Alpha is the scaling factor for input.
float alpha = 1.0;
// Beta is the scaling factor for output.
float beta = 0.0;
auto type =
std::is_same<T, float>::value
? miopenFloat
: (std::is_same<T, Eigen::half>::value ? miopenHalf
: (miopenDataType_t)-1);
ScopedTensorDescriptor src_desc{input_dimensions, type};
ScopedTensorDescriptor dest_desc{output_dimensions, type};
ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
uint8* workspace_ptr = 0;
DeviceMemory<uint8> workspace;
PoolingWorkspaceDescriptor* pdesc = 0;
size_t workspace_size_in_bytes = 0;
auto status = wrap::miopenPoolingGetWorkSpaceSizeV2(
pooling_desc.handle(), dest_desc.handle(), &workspace_size_in_bytes);
if (status != miopenStatusSuccess) {
LOG(ERROR)
<< "failed to obtain workspace size for backward pooling on stream: "
<< ToString(status);
return false;
}
// Allocate the workspace.
if (workspace_size_in_bytes > 0) {
bool cache_hit = m_pooling_cache_allowed &&
m_pooling_cache.find(input_data.opaque(), input_dimensions,
output_dimensions, pooling_dimensions,
type, pdesc);
if (cache_hit) {
assert(pdesc != 0);
workspace_ptr = reinterpret_cast<uint8*>(
pdesc->workspace->mutable_device_memory()->opaque());
VLOG(1) << "Pooling cache hit";
} else {
VLOG(1) << "Pooling cache miss";
assert(workspace_allocator);
auto allocated =
workspace_allocator->AllocateBytes(workspace_size_in_bytes);
if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
LOG(ERROR) << "Failed to allocate backward pooling workspace";
return false;
}
DeviceMemory<uint8> dest2; // duplicated dest from forward:
int64 dest2_size = 0;
// miopen requires the strides and dims to be ordered as BDYX.
std::vector<int64> dims64 =
output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
// miopen does not use strides and must have 4D tensor.
// std::vector<int> dims(pooling_dimensions.ndims() + 2);
dest2_size = sizeof(T);
for (auto& x : dims64) dest2_size *= x;
if (dest2_size > 0) {
assert(workspace_allocator);
auto allocated = workspace_allocator->AllocateBytes(dest2_size);
if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
LOG(ERROR) << "Failed to allocate backward pooling workspace";
return false;
}
} else {
LOG(ERROR) << "Failed to calculate tensor size to chain forward and "
"backward pooling";
}
status = wrap::miopenPoolingForward(
miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
input_data.opaque(), &beta, dest_desc.handle(), dest2.opaque(), true,
workspace.opaque(), workspace_size_in_bytes);
if (status != miopenStatusSuccess) {
LOG(ERROR)
<< "failed to enqueue forward pooling (before backward) on stream: "
<< ToString(status);
return false;
}
workspace_ptr = reinterpret_cast<uint8*>(workspace.opaque());
}
}
status = wrap::miopenPoolingBackward(
miopen.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
output_diff_data->opaque(), workspace_ptr);
if (status != miopenStatusSuccess) {
LOG(ERROR) << "failed to enqueue backward pooling on stream: "
<< ToString(status);
return false;
}
return true;
}
bool MIOpenSupport::DoPoolBackward( bool MIOpenSupport::DoPoolBackward(
Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
const dnn::BatchDescriptor& input_dimensions, const dnn::BatchDescriptor& input_dimensions,
@ -4094,91 +4335,10 @@ bool MIOpenSupport::DoPoolBackward(
const DeviceMemory<float>& input_diff_data, const DeviceMemory<float>& input_diff_data,
DeviceMemory<float>* output_diff_data, DeviceMemory<float>* output_diff_data,
ScratchAllocator* workspace_allocator) { ScratchAllocator* workspace_allocator) {
auto miopen = miopen_->GetHandle(parent_, stream); return DoPoolBackwardImpl(stream, pooling_dimensions, input_dimensions,
input_data, output_dimensions, output_data,
// Alpha is the scaling factor for input. input_diff_data, output_diff_data,
float alpha = 1.0; workspace_allocator);
// Beta is the scaling factor for output.
float beta = 0.0;
ScopedTensorDescriptor src_desc{input_dimensions, miopenFloat};
ScopedTensorDescriptor dest_desc{output_dimensions, miopenFloat};
ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
DeviceMemory<uint8> workspace;
size_t workspace_size_in_bytes = 0;
auto status = wrap::miopenPoolingGetWorkSpaceSize(dest_desc.handle(),
&workspace_size_in_bytes);
if (status != miopenStatusSuccess) {
LOG(ERROR)
<< "failed to obtain workspace size for backward pooling on stream: "
<< ToString(status);
return false;
}
// Allocate the workspace.
if (workspace_size_in_bytes > 0) {
assert(workspace_allocator);
auto allocated =
workspace_allocator->AllocateBytes(workspace_size_in_bytes);
if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
LOG(ERROR) << "Failed to allocate backward pooling workspace";
return false;
}
}
DeviceMemory<uint8> dest2; // duplicated dest from forward:
int dest2_size = 0;
// miopen requires the strides and dims to be ordered as BDYX.
std::vector<int64> dims64 =
output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
// miopen does not use strides and must have 4D tensor.
std::vector<int> dims(4);
std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
&CheckedNarrowing<int64, int>);
dest2_size = dims[0] * dims[1] * dims[2] * dims[3] * sizeof(float);
if (dest2_size > 0) {
assert(workspace_allocator);
auto allocated = workspace_allocator->AllocateBytes(dest2_size);
if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
LOG(ERROR) << "Failed to allocate backward pooling workspace";
return false;
}
} else {
LOG(ERROR) << "Failed to calculate tensor size to chain forward and "
"backward pooling";
}
status = wrap::miopenPoolingForward(
miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
input_data.opaque(), &beta, dest_desc.handle(), dest2.opaque(), true,
workspace.opaque(), workspace_size_in_bytes);
if (status != miopenStatusSuccess) {
LOG(ERROR)
<< "failed to enqueue forward pooling (before backward) on stream: "
<< ToString(status);
return false;
}
status = wrap::miopenPoolingBackward(
miopen.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
dest2.opaque(), dest_desc.handle(), input_diff_data.opaque(),
src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
output_diff_data->opaque(), workspace.opaque());
if (status != miopenStatusSuccess) {
LOG(ERROR) << "failed to enqueue backward pooling on stream: "
<< ToString(status);
return false;
}
return true;
} }
bool MIOpenSupport::DoPoolBackward( bool MIOpenSupport::DoPoolBackward(
@ -4190,91 +4350,10 @@ bool MIOpenSupport::DoPoolBackward(
const DeviceMemory<Eigen::half>& input_diff_data, const DeviceMemory<Eigen::half>& input_diff_data,
DeviceMemory<Eigen::half>* output_diff_data, DeviceMemory<Eigen::half>* output_diff_data,
ScratchAllocator* workspace_allocator) { ScratchAllocator* workspace_allocator) {
auto miopen = miopen_->GetHandle(parent_, stream); return DoPoolBackwardImpl(stream, pooling_dimensions, input_dimensions,
input_data, output_dimensions, output_data,
// Alpha is the scaling factor for input. input_diff_data, output_diff_data,
float alpha = 1.0; workspace_allocator);
// Beta is the scaling factor for output.
float beta = 0.0;
ScopedTensorDescriptor src_desc{input_dimensions, miopenHalf};
ScopedTensorDescriptor dest_desc{output_dimensions, miopenHalf};
ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
DeviceMemory<uint8> workspace;
size_t workspace_size_in_bytes = 0;
auto status = wrap::miopenPoolingGetWorkSpaceSize(dest_desc.handle(),
&workspace_size_in_bytes);
if (status != miopenStatusSuccess) {
LOG(ERROR)
<< "failed to obtain workspace size for backward pooling on stream: "
<< ToString(status);
return false;
}
// Allocate the workspace.
if (workspace_size_in_bytes > 0) {
assert(workspace_allocator);
auto allocated =
workspace_allocator->AllocateBytes(workspace_size_in_bytes);
if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
LOG(ERROR) << "Failed to allocate backward pooling workspace";
return false;
}
}
DeviceMemory<uint8> dest2; // duplicated dest from forward:
int dest2_size = 0;
// miopen requires the strides and dims to be ordered as BDYX.
std::vector<int64> dims64 =
output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
// miopen does not use strides and must have 4D tensor.
std::vector<int> dims(4);
std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
&CheckedNarrowing<int64, int>);
dest2_size = dims[0] * dims[1] * dims[2] * dims[3] * sizeof(float);
if (dest2_size > 0) {
assert(workspace_allocator);
auto allocated = workspace_allocator->AllocateBytes(dest2_size);
if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
LOG(ERROR) << "Failed to allocate backward pooling workspace";
return false;
}
} else {
LOG(ERROR) << "Failed to calculate tensor size to chain forward and "
"backward pooling";
}
status = wrap::miopenPoolingForward(
miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
input_data.opaque(), &beta, dest_desc.handle(), dest2.opaque(), true,
workspace.opaque(), workspace_size_in_bytes);
if (status != miopenStatusSuccess) {
LOG(ERROR)
<< "failed to enqueue forward pooling (before backward) on stream: "
<< ToString(status);
return false;
}
status = wrap::miopenPoolingBackward(
miopen.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
dest2.opaque(), dest_desc.handle(), input_diff_data.opaque(),
src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
output_diff_data->opaque(), workspace.opaque());
if (status != miopenStatusSuccess) {
LOG(ERROR) << "failed to enqueue backward pooling on stream: "
<< ToString(status);
return false;
}
return true;
} }
bool MIOpenSupport::DoNormalizeWithDimensions( bool MIOpenSupport::DoNormalizeWithDimensions(

View File

@ -20,6 +20,7 @@ limitations under the License.
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DNN_H_ #define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DNN_H_
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "rocm/include/miopen/miopen.h"
#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/dnn.h" #include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/status.h"
@ -38,6 +39,39 @@ class MIOpenCTCLossDescriptor;
// Opaque and unique identifier for the MIOpen plugin. // Opaque and unique identifier for the MIOpen plugin.
extern const PluginId kMIOpenPlugin; extern const PluginId kMIOpenPlugin;
struct PoolingWorkspaceDescriptor {
std::vector<int64> input_dims;
std::vector<int64> output_dims;
dnn::PoolingDescriptor op;
int dtype;
uint64_t timestamp;
std::unique_ptr<TemporaryDeviceMemory<uint8>> workspace;
size_t workspace_size;
bool IsSame(const dnn::BatchDescriptor& input_dimensions,
const dnn::BatchDescriptor& output_dimensions,
const dnn::PoolingDescriptor& pooling_dimensions, int _type);
};
struct PoolingWorkspaceCache {
std::map<const void*, PoolingWorkspaceDescriptor> cache;
const int trim_size = 1000;
const uint64_t memory_budget = 2e7;
uint64_t timestamp = 0;
uint64_t memory_used = 0;
bool find(const void* p, const dnn::BatchDescriptor& input_dimensions,
const dnn::BatchDescriptor& output_dimensions,
const dnn::PoolingDescriptor& pooling_dimensions, int _type,
PoolingWorkspaceDescriptor*& pdesc);
void insert(const void* p, const dnn::BatchDescriptor& input_dimensions,
const dnn::BatchDescriptor& output_dimensions,
const dnn::PoolingDescriptor& pooling_dimensions, int _type,
std::unique_ptr<TemporaryDeviceMemory<uint8>>& workspace,
size_t wsp_size, hipStream_t hip_stream);
private:
void trim(hipStream_t hip_stream);
};
// miopen-library based DNN support. For details on overridden interface // miopen-library based DNN support. For details on overridden interface
// functions, see dnn.h. // functions, see dnn.h.
class MIOpenSupport : public dnn::DnnSupport { class MIOpenSupport : public dnn::DnnSupport {
@ -664,6 +698,10 @@ class MIOpenSupport : public dnn::DnnSupport {
// Provide access to the MIOpen handle. // Provide access to the MIOpen handle.
std::unique_ptr<class MIOpenAccess> miopen_; std::unique_ptr<class MIOpenAccess> miopen_;
PoolingWorkspaceCache m_pooling_cache;
bool m_pooling_cache_allowed = false;
bool m_pooling_cache_enabled = false;
template <class T, class U> template <class T, class U>
bool DoBatchNormalizationForwardImpl( bool DoBatchNormalizationForwardImpl(
Stream* stream, dnn::DataType input_data_type, Stream* stream, dnn::DataType input_data_type,
@ -847,6 +885,36 @@ class MIOpenSupport : public dnn::DnnSupport {
ScratchAllocator* scratch_allocator, ScratchAllocator* scratch_allocator,
std::vector<dnn::ProfileResult>* out_algorithms); std::vector<dnn::ProfileResult>* out_algorithms);
port::Status DoCtcLossImpl(
Stream* stream, const MIOpenRnnStateTensorDescriptor& probs_desc,
const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
const MIOpenRnnStateTensorDescriptor& grads_desc,
DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc,
DeviceMemory<uint8> scratch_memory);
port::Status DoPrepareForCtcLoss(
Stream* stream, dnn::DataType element_type,
const dnn::RnnStateTensorDescriptor& probs_desc,
const dnn::RnnStateTensorDescriptor& grads_desc,
absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data,
ScratchAllocator* scratch_allocator,
DeviceMemory<uint8>* scratch_memory) override;
template <class T>
bool DoPoolBackwardImpl(Stream* stream,
const dnn::PoolingDescriptor& pooling_dimensions,
const dnn::BatchDescriptor& input_dimensions,
const DeviceMemory<T>& input_data,
const dnn::BatchDescriptor& output_dimensions,
const DeviceMemory<T>& output_data,
const DeviceMemory<T>& input_diff_data,
DeviceMemory<T>* output_diff_data,
ScratchAllocator* workspace_allocator = nullptr);
SE_DISALLOW_COPY_AND_ASSIGN(MIOpenSupport); SE_DISALLOW_COPY_AND_ASSIGN(MIOpenSupport);
}; };