Merge pull request #40320 from ROCmSoftwarePlatform:google-upstream-eugene-3dpool

PiperOrigin-RevId: 317868942
Change-Id: I47a6de9d0c270a3d38abacd748fa57903e6b0673
This commit is contained in:
TensorFlower Gardener 2020-06-23 08:36:07 -07:00
commit 048ff6ab46
8 changed files with 504 additions and 331 deletions

View File

@ -259,9 +259,6 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) {
RunTest(x, x_init_value, y, y_shape);
}
// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, MaxPool3DGradHelper) {
TensorShape x_shape({1, 3, 3, 3, 1});
TensorShape y_shape({1, 1, 1, 1, 1});
@ -274,7 +271,6 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) {
SetRandomValuesForMaxPooling<float>(&x_init_value);
RunTest(x, x_init_value, y, y_shape);
}
#endif
TEST_F(NNGradTest, AvgPoolGradHelper) {
TensorShape x_shape({1, 2, 2, 1});
@ -287,9 +283,6 @@ TEST_F(NNGradTest, AvgPoolGradHelper) {
RunTest(x, x_shape, y, y_shape);
}
// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, AvgPool3DGradHelper) {
TensorShape x_shape({1, 3, 3, 3, 1});
TensorShape y_shape({1, 1, 1, 1, 1});
@ -300,7 +293,6 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) {
auto y = AvgPool3D(scope_, x, ksize, strides, "SAME");
RunTest(x, x_shape, y, y_shape);
}
#endif
TEST_F(NNGradTest, LRN) {
TensorShape x_shape({1, 1, 2, 1});

View File

@ -98,10 +98,25 @@ void DnnPooling3dOp<T>::Compute(OpKernelContext* context,
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
#if TENSORFLOW_USE_ROCM
static int64 PoolingScratchSize = GetDnnWorkspaceLimit(
// default value is in bytes despite the name of the environment variable
"TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB
);
DnnScratchAllocator scratch_allocator(PoolingScratchSize, context);
bool status =
stream
->ThenPoolForward(pooling_desc, input_desc, input_data, output_desc,
&output_data, &scratch_allocator)
.ok();
#else
bool status = stream
->ThenPoolForward(pooling_desc, input_desc, input_data,
output_desc, &output_data)
.ok();
#endif
OP_REQUIRES(context, status,
errors::Internal("dnn PoolForward launch failed"));
@ -225,12 +240,28 @@ void DnnPooling3dGradOp<T>::Compute(
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
#if TENSORFLOW_USE_ROCM
static int64 PoolingScratchSize = GetDnnWorkspaceLimit(
// default value is in bytes despite the name of the environment variable
"TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB
);
DnnScratchAllocator scratch_allocator(PoolingScratchSize, context);
bool status = stream
->ThenPoolBackward(pooling_desc, orig_input_desc,
orig_input_data, orig_output_desc,
orig_output_data, output_backprop_data,
&input_backprop_data, &scratch_allocator)
.ok();
#else
bool status =
stream
->ThenPoolBackward(pooling_desc, orig_input_desc, orig_input_data,
orig_output_desc, orig_output_data,
output_backprop_data, &input_backprop_data)
.ok();
#endif
OP_REQUIRES(context, status,
errors::Internal("dnn PoolBackward launch failed"));

View File

@ -84,8 +84,9 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
tf_y = tf_g1 * tf_g2 * tf_g3
tf_grad = gradients.gradients(tf_y, [tf_var])[0]
tf_dense_grad = math_ops.unsorted_segment_sum(
tf_grad.values, tf_grad.indices, tf_grad.dense_shape[0])
tf_dense_grad = math_ops.unsorted_segment_sum(tf_grad.values,
tf_grad.indices,
tf_grad.dense_shape[0])
self.assertAllClose(grad, self.evaluate(tf_dense_grad))
@ -127,8 +128,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(grads_and_vars[0][0], 1.0)
self.assertAllEqual(id(grads_and_vars[0][1]), id(x))
@parameterized.named_parameters(
[('Function', def_function.function),
@parameterized.named_parameters([('Function', def_function.function),
('NoFunction', lambda f: f)])
def testNoOpBehaviorConsistent(self, decorator):
@ -195,8 +195,10 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
@custom_gradient.custom_gradient
def identity(x):
def grad(_):
return [] # This return value is wrong!
return x, grad
x = variables.Variable(1.0)
@ -234,8 +236,10 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
@custom_gradient.custom_gradient
def f(x):
def grad(_):
raise RuntimeError('x')
return x, grad
# TODO(apassos) raise the right error here
@ -337,7 +341,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
x = constant_op.constant(2.0)
with backprop.GradientTape() as t:
t.watch(x)
y = x*x
y = x * x
self.assertEqual(t.gradient([x, y], x).numpy(), 5.0)
def testTapeNoOpGradientWithMultiTargetAllSource(self):
@ -441,9 +445,11 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(t.gradient(loss, v), 2.0)
def testPythonMax(self):
x = [resource_variable_ops.ResourceVariable(2.),
x = [
resource_variable_ops.ResourceVariable(2.),
resource_variable_ops.ResourceVariable(3.),
resource_variable_ops.ResourceVariable(5.)]
resource_variable_ops.ResourceVariable(5.)
]
with backprop.GradientTape() as t:
f = max(x)
grad = t.gradient(f, x)
@ -538,8 +544,8 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
with backprop.GradientTape() as tape2:
tape1.watch(x1)
tape2.watch([x1, x2])
y = x1 ** 3
z = x2 ** 2
y = x1**3
z = x2**2
dy, dz = tape2.gradient([y, z], [x1, x2])
d2y, d2z = tape1.gradient([dy, dz], [x1, x2])
@ -602,6 +608,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
@test_util.assert_no_new_tensors
def testArgmax(self):
def argmax(x):
i = math_ops.argmax(x)
return array_ops.stop_gradient(i)
@ -612,6 +619,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
@test_util.run_gpu_only
@test_util.assert_no_new_tensors
def testGPU(self):
def fn(x):
with context.device('/gpu:0'):
b = constant_op.constant(2.0)
@ -634,8 +642,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
with context.device('gpu:0'):
return v.read_value()
self.assertEqual(
backprop.implicit_grad(f)()[0][0].cpu().numpy(), 1.0)
self.assertEqual(backprop.implicit_grad(f)()[0][0].cpu().numpy(), 1.0)
@test_util.assert_no_new_tensors
def testCPU(self):
@ -651,6 +658,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
@test_util.run_gpu_only
@test_util.assert_no_new_tensors
def testTensorCopyGPU2CPU2GPU(self):
def f(a, b):
return a.cpu() + b.cpu()
@ -675,8 +683,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
@test_util.assert_no_new_tensors
def testUnconnectedNone(self):
v = resource_variable_ops.ResourceVariable(
1.0, name='testUnconnectedNone')
v = resource_variable_ops.ResourceVariable(1.0, name='testUnconnectedNone')
def f():
v.read_value()
@ -690,9 +697,9 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
with g:
x = constant_op.constant(3.0)
g.watch(x)
y = 2*x
y = 2 * x
with g:
z = 2*y
z = 2 * y
grad = g.gradient(target=z, sources=[x])
self.assertEqual(self.evaluate(grad), [4.0])
@ -736,12 +743,20 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
self.assertEqual(self.evaluate(g.gradient(y, x1)), [1.0])
self.assertEqual(self.evaluate(g.gradient(y, (x1,))), (1.0,))
self.assertEqual(self.evaluate(g.gradient(y, (x1, x2))), (1.0, 2.0))
self.assertEqual(self.evaluate(g.gradient(y, [(x1, x2), (x2, x3)])),
[(1.0, 2.0), (2.0, 3.0)])
self.assertEqual(self.evaluate(g.gradient(y, (x1, x2, [x1, x3]))),
self.assertEqual(
self.evaluate(g.gradient(y, [(x1, x2), (x2, x3)])), [(1.0, 2.0),
(2.0, 3.0)])
self.assertEqual(
self.evaluate(g.gradient(y, (x1, x2, [x1, x3]))),
(1.0, 2.0, [1.0, 3.0]))
self.assertEqual(self.evaluate(g.gradient(y, [x1, {'x2': x2, 'x3': x3}])),
[1.0, {'x2': 2.0, 'x3': 3.0}])
self.assertEqual(
self.evaluate(g.gradient(y, [x1, {
'x2': x2,
'x3': x3
}])), [1.0, {
'x2': 2.0,
'x3': 3.0
}])
@test_util.assert_no_new_tensors
@test_util.run_in_graph_and_eager_modes
@ -846,13 +861,13 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
with backprop.GradientTape(persistent=True) as g:
x = constant_op.constant(3.0)
g.watch(x)
y = x ** 3 # y := x^3
y = x**3 # y := x^3
dy_dx = g.gradient(y, x) # dy/dx := 3x^2
d2y_dx2 = g.gradient(dy_dx, x) # d2y/dx2 := 6x
d3y_dx3 = g.gradient(d2y_dx2, x) # d3y/dx3 := 6
x = 3
self.assertEqual(self.evaluate(y), x ** 3)
self.assertEqual(self.evaluate(dy_dx), 3 * x ** 2)
self.assertEqual(self.evaluate(y), x**3)
self.assertEqual(self.evaluate(dy_dx), 3 * x**2)
self.assertEqual(self.evaluate(d2y_dx2), 6 * x)
self.assertEqual(self.evaluate(d3y_dx3), 6)
del g
@ -973,19 +988,17 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
x = constant_op.constant(1.)
with backprop.GradientTape() as g:
g.watch(x)
tape_lib.record_operation(
'InvalidBackprop',
[y],
[x],
lambda dy: [])
with self.assertRaisesRegexp(
errors_impl.InternalError, 'InvalidBackprop.*too few gradients'):
tape_lib.record_operation('InvalidBackprop', [y], [x], lambda dy: [])
with self.assertRaisesRegexp(errors_impl.InternalError,
'InvalidBackprop.*too few gradients'):
g.gradient(y, x)
@test_util.assert_no_new_tensors
def testEmptyParamsForValueAndGradFunction(self):
def fn(a, b):
return a * b
val_and_grads_fn = backprop.val_and_grad_function(fn)
x = 2.0
@ -997,8 +1010,10 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
@test_util.assert_no_new_tensors
def testNonEmptyParamsForValueAndGradFunction(self):
def fn(a, b):
return a * b
val_and_grad_fn = backprop.val_and_grad_function(fn, params=[1])
x = 2.0
@ -1046,9 +1061,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
def mul(x):
return math_ops._mul_dispatch(x, x) # pylint: disable=protected-access
self.assertAllEqual(
backprop.gradients_function(mul)(3.0)[0].numpy(),
6.0)
self.assertAllEqual(backprop.gradients_function(mul)(3.0)[0].numpy(), 6.0)
def testMakeAttrShape(self):
for s in ([], None, [1, 2, 3], [None, None], [1, None, 3]):
@ -1057,8 +1070,8 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
self.assertEqual(
expected,
actual,
msg=('For shape %r, expected %r != %r actual' % (s, expected,
actual)))
msg=('For shape %r, expected %r != %r actual' %
(s, expected, actual)))
def testMakeAttrShapeList(self):
shape_list = [[], None, [1, 2, 3], [None, None], [1, None, 3]]
@ -1081,8 +1094,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
part = functools.partial(f, constant_op.constant(2.0))
self.assertAllEqual(
backprop.gradients_function(part)(constant_op.constant(1.0))[0],
2.0)
backprop.gradients_function(part)(constant_op.constant(1.0))[0], 2.0)
def testReturnSameThing(self):
@ -1238,10 +1250,11 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
@custom_gradient.custom_gradient
def my_mul(x, y):
result = x*y
result = x * y
def grad(dr):
return [dr*y, dr*x]
return [dr * y, dr * x]
return result, grad
lr = 0.25
@ -1257,7 +1270,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
loss, grads_and_vars = loss_grads_fn(x)
losses.append(loss.numpy())
for (grad, var) in grads_and_vars:
var.assign_sub(lr*grad)
var.assign_sub(lr * grad)
self.assertAllEqual(losses, [4.0, 3., 2., 1., 0.])
@test_util.assert_no_new_tensors
@ -1276,7 +1289,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
def testDifferentiatingFunctionThatReturnsNone(self):
def fn(x, y):
result = x*y # pylint: disable=unused-variable
result = x * y # pylint: disable=unused-variable
x = constant_op.constant(1)
y = constant_op.constant(2)
@ -1295,6 +1308,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
def testZerosCacheDoesntLeakAcrossGraphs(self):
with ops.Graph().as_default():
def get_grad():
with ops.Graph().as_default(), self.cached_session():
t = constant_op.constant(1, dtype=dtypes.float32, shape=(10, 4))
@ -1378,6 +1392,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
@test_util.run_in_graph_and_eager_modes
def testCustomGradientInEagerAndGraph(self):
@custom_gradient.custom_gradient
def f(x):
y = x * x
@ -1394,10 +1409,12 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(self.evaluate(t.gradient(g, c)), 4.0)
def testOverrideSecondOrderWithCustomGradient(self):
@custom_gradient.custom_gradient
def f(x):
def first_order_grad(dz):
@custom_gradient.custom_gradient
def first_order_custom(unused_x):
@ -1405,6 +1422,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
return -2.1 * ddz
return -1.1, h
return dz * first_order_custom(x)
return x + 10., first_order_grad
@ -1414,25 +1432,31 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
outer.watch(c)
with backprop.GradientTape() as inner:
inner.watch(c)
d = f(c) ** 4.
d = f(c)**4.
dd = inner.gradient(d, c)
self.assertAllClose(4. * f(c) ** 3. * -1.1, dd)
self.assertAllClose(3. * 4. * f(c) ** 2. * -1.1 * -1.1
+ 4. * f(c) ** 3. * -2.1,
self.assertAllClose(4. * f(c)**3. * -1.1, dd)
self.assertAllClose(3. * 4. * f(c)**2. * -1.1 * -1.1 + 4. * f(c)**3. * -2.1,
outer.gradient(dd, c))
@test_util.run_in_graph_and_eager_modes
def testCustomGradientForwardprop(self):
@custom_gradient.custom_gradient
def f(x):
z = 2. * tensor_util.constant_value(x)
def g(dz):
@custom_gradient.custom_gradient
def first_order(unused_x, unused_dz):
def second_order_and_transpose(unused_ddz):
return 2.2, 3.1
return 2.1, second_order_and_transpose
return first_order(x, dz)
return z, g
with backprop.GradientTape(persistent=True) as t:
@ -1457,9 +1481,6 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
@test_util.run_in_graph_and_eager_modes
def testMaxPooling3DGradient(self):
if test.is_built_with_rocm():
self.skipTest('Pooling with 3D tensors is not supported in ROCm')
def forward(a):
r = max_pooling3d(a, pool_size=pool_size, strides=strides, padding='SAME')
return r
@ -1491,9 +1512,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
with backprop.GradientTape() as t:
values = constant_op.constant([1.0, 2.0], dtypes.float32)
s = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 2]],
values=values,
dense_shape=[3, 4])
indices=[[0, 0], [1, 2]], values=values, dense_shape=[3, 4])
t.watch(s)
z = sparse_ops.sparse_reduce_sum_v2(s)
result = t.gradient(z, values)
@ -1529,6 +1548,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
self.assertEqual((z,), tape.watched_variables())
def testNameScope(self):
def fn(x):
with ops.name_scope('my_scope'):
a = math_ops.cos(x)
@ -1592,8 +1612,8 @@ class JacobianTest(test.TestCase):
g.watch(x)
g.watch(y)
z = x * x * y
jacobian = g.jacobian(z, [x, y],
experimental_use_pfor=experimental_use_pfor)
jacobian = g.jacobian(
z, [x, y], experimental_use_pfor=experimental_use_pfor)
answer = [array_ops.diag(2 * x * y), array_ops.diag(x * x)]
return jacobian, answer
@ -1648,7 +1668,8 @@ class JacobianTest(test.TestCase):
x = constant_op.constant([[1., 2], [3, 4]])
g.watch(x)
y = math_ops.matmul(x, x)
self.assertAllClose(g.jacobian(y, x, parallel_iterations=2),
self.assertAllClose(
g.jacobian(y, x, parallel_iterations=2),
g.jacobian(y, x, parallel_iterations=3))
@test_util.run_in_graph_and_eager_modes
@ -1690,7 +1711,8 @@ class BatchJacobianTest(test.TestCase, parameterized.TestCase):
z = x * x * y
batch_jacobian = g.batch_jacobian(
z, x, experimental_use_pfor=experimental_use_pfor)
answer = array_ops.stack([array_ops.diag(2 * x[0] * y[0]),
answer = array_ops.stack(
[array_ops.diag(2 * x[0] * y[0]),
array_ops.diag(2 * x[1] * y[1])])
return batch_jacobian, answer
@ -1757,13 +1779,11 @@ class BatchJacobianTest(test.TestCase, parameterized.TestCase):
g.watch(x)
w = constant_op.constant([[1., 2, 3, 4], [5, 6, 7, 8]])
y = math_ops.matmul(x, w)
self.assertAllClose(g.batch_jacobian(y, x, parallel_iterations=2),
self.assertAllClose(
g.batch_jacobian(y, x, parallel_iterations=2),
g.batch_jacobian(y, x, parallel_iterations=3))
@parameterized.parameters(
(True, True),
(True, False),
(False, True),
@parameterized.parameters((True, True), (True, False), (False, True),
(False, False))
def test_degenerate_shape(self, use_function, use_pfor):

View File

@ -2989,7 +2989,6 @@ cuda_py_test(
name = "pooling_ops_3d_test",
size = "medium",
srcs = ["pooling_ops_3d_test.py"],
tags = ["no_rocm"],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",

View File

@ -65,8 +65,8 @@ def pool_direct_single_axis(
input_size = input.shape[axis]
if padding == "SAME":
output_size = int(math.ceil(input_size / stride))
total_padding_amount = max(
0, (output_size - 1) * stride + effective_window_size - input_size)
total_padding_amount = max(0, (output_size - 1) * stride +
effective_window_size - input_size)
before_padding = total_padding_amount // 2
elif padding == "VALID":
output_size = int(
@ -219,8 +219,6 @@ class PoolingTest(test.TestCase):
strides=strides)
def testPool3D(self):
if test.is_built_with_rocm():
self.skipTest("Pooling with 3D tensors is not supported in ROCm")
with self.session(use_gpu=test.is_gpu_available()):
for padding in ["SAME", "VALID"]:
for pooling_type in ["MAX", "AVG"]:
@ -302,8 +300,10 @@ class PoolingTest(test.TestCase):
x = constant_op.constant(x_val, name="x", dtype=dtypes.float32)
output = nn_ops.pool(input=x, **kwargs)
y_shape = output.get_shape().as_list()
err = gradient_checker.compute_gradient_error(
[x], [input_shape], output, y_shape, x_init_value=[x_val])
err = gradient_checker.compute_gradient_error([x], [input_shape],
output,
y_shape,
x_init_value=[x_val])
err_tolerance = 1e-2
self.assertLess(err, err_tolerance)
@ -363,8 +363,6 @@ class PoolingTest(test.TestCase):
@test_util.run_deprecated_v1
def testGradient3D(self):
if test.is_built_with_rocm():
self.skipTest("Pooling with 3D tensors is not supported in ROCm")
with self.session(use_gpu=test.is_gpu_available()):
for padding in ["SAME", "VALID"]:
for pooling_type in ["AVG", "MAX"]:

View File

@ -435,11 +435,7 @@ class NNTest(PForTestCase):
with g:
x1 = array_ops.gather(x, i)
output = nn.avg_pool3d(
x1,
ksize,
strides=strides,
padding="VALID",
data_format="NDHWC")
x1, ksize, strides=strides, padding="VALID", data_format="NDHWC")
loss = nn.l2_loss(output)
return output, g.gradient(loss, x1)
@ -488,8 +484,6 @@ class NNTest(PForTestCase):
self._test_loop_fn(loop_fn, 3)
def test_max_pool3d(self):
if test.is_built_with_rocm():
self.skipTest("Pooling with 3D tensors is not supported in ROCm")
with backprop.GradientTape(persistent=True) as g:
x = random_ops.random_uniform([3, 3, 2, 12, 12, 3])
g.watch(x)
@ -1012,12 +1006,12 @@ class TensorListTest(PForTestCase):
# TensorListReserve operation.
v2_enabled = control_flow_v2_toggles.control_flow_v2_enabled()
control_flow_v2_toggles.enable_control_flow_v2()
def loop_fn(i):
handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
_, out_handle = control_flow_ops.while_loop(
lambda j, _: j < 2,
lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, i)),
(0, handle))
lambda j, _: j < 2, lambda j, h:
(j + 1, list_ops.tensor_list_set_item(h, j, i)), (0, handle))
return list_ops.tensor_list_stack(out_handle, dtypes.int32)
self._test_loop_fn(loop_fn, 2)
@ -1140,9 +1134,8 @@ class WhileV1Test(PForTestCase):
def loop_fn(_):
return control_flow_ops.while_loop(
lambda j, x: j < 4,
lambda j, x: (j + 1, x + random_ops.random_uniform([])),
[0, 0.])[0]
lambda j, x: j < 4, lambda j, x:
(j + 1, x + random_ops.random_uniform([])), [0, 0.])[0]
self._test_loop_fn(loop_fn, 3)
@ -1150,9 +1143,8 @@ class WhileV1Test(PForTestCase):
def test_while_unstacked_condition(self):
def loop_fn(i):
return control_flow_ops.while_loop(
lambda j, x: j < 4,
lambda j, x: (j + 1, x + i), [0, 0])
return control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
(j + 1, x + i), [0, 0])
self._test_loop_fn(loop_fn, 3)
@ -1166,8 +1158,8 @@ class WhileV1Test(PForTestCase):
lengths_i = array_ops.gather(lengths, i)
_, total = control_flow_ops.while_loop(
lambda j, _: j < lengths_i,
lambda j, t: (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
lambda j, _: j < lengths_i, lambda j, t:
(j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
return total
self._test_loop_fn(loop_fn, 3)
@ -1371,6 +1363,7 @@ class WhileV2Test(PForTestCase):
super(WhileV2Test, self).tearDown()
def test_while_outside_loop(self):
def _f():
return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])
@ -1399,9 +1392,8 @@ class WhileV2Test(PForTestCase):
def loop_fn(_):
j, _ = control_flow_ops.while_loop(
lambda j, x: j < 4,
lambda j, x: (j + 1, x + random_ops.random_uniform([])),
[0, 0.])
lambda j, x: j < 4, lambda j, x:
(j + 1, x + random_ops.random_uniform([])), [0, 0.])
return j
self._test_loop_fn(loop_fn, 3)
@ -1410,9 +1402,8 @@ class WhileV2Test(PForTestCase):
v = resource_variable_ops.ResourceVariable(5.)
def loop_fn(_):
_, output = control_flow_ops.while_loop(
lambda j, x: j < 4,
lambda j, x: (j + 1, x + v), [0, 0.])
_, output = control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
(j + 1, x + v), [0, 0.])
return output
self._test_loop_fn(loop_fn, 3)
@ -1420,9 +1411,8 @@ class WhileV2Test(PForTestCase):
def test_while_unstacked_condition(self):
def loop_fn(i):
return control_flow_ops.while_loop(
lambda j, x: j < 4,
lambda j, x: (j + 1, x + i), [0, 0])
return control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
(j + 1, x + i), [0, 0])
self._test_loop_fn(loop_fn, 3)
@ -1435,8 +1425,8 @@ class WhileV2Test(PForTestCase):
lengths_i = array_ops.gather(lengths, i)
return control_flow_ops.while_loop(
lambda j, _: j < lengths_i,
lambda j, t: (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
lambda j, _: j < lengths_i, lambda j, t:
(j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
self._test_loop_fn(loop_fn, 3)
@ -1446,25 +1436,29 @@ class WhileV2Test(PForTestCase):
# It also test inputs that are passed through.
def loop_fn(i):
return control_flow_ops.while_loop(
lambda j, *_: j < i,
lambda j, x, y, z, w: (j + 1, x + i, y + x, z, w),
[0,
lambda j, *_: j < i, lambda j, x, y, z, w:
(j + 1, x + i, y + x, z, w), [
0,
constant_op.constant(0),
constant_op.constant(1),
i,
constant_op.constant(2)])
constant_op.constant(1), i,
constant_op.constant(2)
])
self._test_loop_fn(loop_fn, 3)
def test_while_shape_invariants(self):
def loop_fn(i):
return control_flow_ops.while_loop(
lambda j, *_: j < 4,
lambda j, x, y: (j + 1, x + i, y + 1),
[0, constant_op.constant([0, 1]), constant_op.constant([2, 3])],
shape_invariants=[None,
[0, constant_op.constant([0, 1]),
constant_op.constant([2, 3])],
shape_invariants=[
None,
tensor_shape.TensorShape([2]),
tensor_shape.TensorShape([2])])
tensor_shape.TensorShape([2])
])
self._test_loop_fn(loop_fn, 3)
@ -1486,8 +1480,8 @@ class WhileV2Test(PForTestCase):
if use_pfor:
return pfor_control_flow_ops.pfor(loop_fn, iters=3)
else:
return pfor_control_flow_ops.for_loop(loop_fn, iters=3,
loop_fn_dtypes=out.dtype)
return pfor_control_flow_ops.for_loop(
loop_fn, iters=3, loop_fn_dtypes=out.dtype)
x = constant_op.constant(np.random.uniform(size=(1, 3)))
y = constant_op.constant(np.random.uniform(size=(3, 3)))
@ -1512,9 +1506,9 @@ class NestedControlFlowTest(PForTestCase):
f = lambda x, y: (x, y)
def _f(x, y):
return control_flow_ops.cond(y > split,
lambda: f(x, y),
lambda: (x + 1., y))
return control_flow_ops.cond(y > split, lambda: f(x, y), lambda:
(x + 1., y))
return _f
def _while(self, f=None):
@ -1523,9 +1517,8 @@ class NestedControlFlowTest(PForTestCase):
def _f(x, y):
return control_flow_ops.while_loop(
lambda j, _: j < y,
lambda j, t: (j + 1, t + array_ops.gather(f(x, y)[0], j)),
[0, x])[1], y
lambda j, _: j < y, lambda j, t:
(j + 1, t + array_ops.gather(f(x, y)[0], j)), [0, x])[1], y
return _f
@ -1566,10 +1559,8 @@ class StatelessIfTest(PForTestCase):
x_i = array_ops.gather(x, i)
# Note that the output has a combination of then and else branches being
# loop variant / invariant.
return cond_v2.cond_v2(
x_i < y,
lambda: (y - x_i, y, 1., 2.),
lambda: (x_i - y, 0., y, 3.))
return cond_v2.cond_v2(x_i < y, lambda: (y - x_i, y, 1., 2.), lambda:
(x_i - y, 0., y, 3.))
self._test_loop_fn(loop_fn, iters=5)
@ -1583,10 +1574,8 @@ class StatelessIfTest(PForTestCase):
x_i = array_ops.gather(x, i)
# Note that the output has a combination of then and else branches being
# loop variant / invariant.
return cond_v2.cond_v2(
z < y,
lambda: (y - x_i, y, 1., 2.),
lambda: (x_i - y, 0., y, 3.))
return cond_v2.cond_v2(z < y, lambda: (y - x_i, y, 1., 2.), lambda:
(x_i - y, 0., y, 3.))
self._test_loop_fn(loop_fn, iters=5)
@ -1619,10 +1608,7 @@ class IfTest(PForTestCase):
@def_function.function
def loop_fn(i):
x_i = array_ops.gather(x, i)
return cond_v2.cond_v2(
x_i < y,
lambda: z - x_i,
lambda: z + x_i)
return cond_v2.cond_v2(x_i < y, lambda: z - x_i, lambda: z + x_i)
self._test_loop_fn(loop_fn, iters=5)
@ -1736,9 +1722,8 @@ class Benchmarks(test.Benchmark):
with ops.Graph().as_default():
def loop_fn(i):
_, s = control_flow_ops.while_loop(lambda t, x: t < i,
lambda t, x: (t + 1, x + i),
[0, 0])
_, s = control_flow_ops.while_loop(lambda t, x: t < i, lambda t, x:
(t + 1, x + i), [0, 0])
return s
iters = 50
@ -2122,5 +2107,6 @@ class VariableTest(PForTestCase):
self._test_loop_fn(loop_fn, 2)
if __name__ == "__main__":
test.main()

View File

@ -263,7 +263,8 @@ namespace wrap {
__macro(miopenFindConvolutionForwardAlgorithm) \
__macro(miopenCreateTensorDescriptor) \
__macro(miopenDestroyTensorDescriptor) \
__macro(miopenSet2dPoolingDescriptor) \
__macro(miopenSetNdPoolingDescriptor) \
__macro(miopenSetPoolingIndexType) \
__macro(miopenSetLRNDescriptor) \
__macro(miopenLRNGetWorkSpaceSize) \
__macro(miopenCreateConvolutionDescriptor) \
@ -290,7 +291,7 @@ namespace wrap {
__macro(miopenSetTensorDescriptor) \
__macro(miopenGetTensorDescriptorSize) \
__macro(miopenPoolingForward) \
__macro(miopenPoolingGetWorkSpaceSize) \
__macro(miopenPoolingGetWorkSpaceSizeV2 \
__macro(miopenPoolingBackward) \
__macro(miopenLRNForward) \
__macro(miopenLRNBackward) \
@ -605,6 +606,11 @@ MIOpenSupport::MIOpenSupport(GpuExecutor* parent) : parent_(parent) {
// swich to Find Mode if env var TF_ROCM_USE_IMMEDIATE_MODE is set
tensorflow::ReadBoolFromEnvVar("TF_ROCM_USE_IMMEDIATE_MODE", false,
&use_immediate_mode_);
bool enable_pooling_cache = false;
tensorflow::ReadBoolFromEnvVar("TF_ROCM_BW_POOL_CACHE", false,
&enable_pooling_cache);
if (enable_pooling_cache) m_pooling_cache_allowed = true;
}
port::Status MIOpenSupport::Init() {
@ -844,17 +850,19 @@ class ScopedPoolingDescriptor {
std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
&CheckedNarrowing<int64, int>);
if (nd != 2) {
LOG(FATAL) << "miopen requires pooling dimensions be 2"
<< ToString(status);
}
status = wrap::miopenSet2dPoolingDescriptor(
status = wrap::miopenSetNdPoolingDescriptor(
handle_,
(pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
? miopenPoolingMax
: miopenPoolingAverage),
shape[0], shape[1], padding[0], padding[1], strides[0], strides[1]);
nd, shape.data(), padding.data(), strides.data());
// Note: The index type has to be uint32 type for now because MIOpen
// API assumes all input indexes to be the same type. Since a tensor
// descriptor can only use int32 type, the index type here need to be
// aligned with the tensor index type of the (input) tensor descritptor
status = wrap::miopenSetPoolingIndexType(handle_, miopenIndexUint32);
if (status != miopenStatusSuccess) {
LOG(FATAL) << "could not set miopen pooling descriptor: "
<< ToString(status);
@ -4009,10 +4017,94 @@ bool MIOpenSupport::DoPoolForward(
const DeviceMemory<double>& input_data,
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<double>* output_data, ScratchAllocator* workspace_allocator) {
LOG(ERROR) << "miopen does not support pooling for dobule type yet";
LOG(ERROR) << "miopen does not support pooling for double type yet";
return false;
}
bool PoolingWorkspaceDescriptor::IsSame(
const dnn::BatchDescriptor& input_dimensions,
const dnn::BatchDescriptor& output_dimensions,
const dnn::PoolingDescriptor& pooling_dimensions, int _type) {
return dtype == _type &&
input_dims ==
input_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX) &&
output_dims ==
output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX) &&
op.mode() == pooling_dimensions.mode() &&
op.window() == pooling_dimensions.window() &&
op.padding() == pooling_dimensions.padding() &&
op.strides() == pooling_dimensions.strides();
}
bool PoolingWorkspaceCache::find(
const void* p, const dnn::BatchDescriptor& input_dimensions,
const dnn::BatchDescriptor& output_dimensions,
const dnn::PoolingDescriptor& pooling_dimensions, int _type,
PoolingWorkspaceDescriptor*& pdesc) {
pdesc = 0;
auto it = cache.find(p);
if (it == cache.end()) {
return false;
}
if (!it->second.IsSame(input_dimensions, output_dimensions,
pooling_dimensions, _type)) {
return false;
}
pdesc = &it->second;
return true;
}
void PoolingWorkspaceCache::insert(
const void* p, const dnn::BatchDescriptor& input_dimensions,
const dnn::BatchDescriptor& output_dimensions,
const dnn::PoolingDescriptor& pooling_dimensions, int _type,
std::unique_ptr<TemporaryDeviceMemory<uint8>>& workspace, size_t wsp_size,
hipStream_t hip_stream) {
PoolingWorkspaceDescriptor* desc = 0;
auto it = cache.find(p);
if (it != cache.end()) {
// replacing an entry with the same pointer but different attributes
// (if everything matches, the caller is expected to reuse the entry)
desc = &it->second;
hipStreamSynchronize(hip_stream);
memory_used -= desc->workspace_size;
} else {
cache[p] = PoolingWorkspaceDescriptor();
desc = &cache[p];
}
desc->input_dims = input_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
desc->output_dims =
output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
desc->op = pooling_dimensions;
desc->dtype = _type;
desc->timestamp = timestamp;
timestamp++;
desc->workspace = std::move(workspace);
desc->workspace_size = wsp_size;
memory_used += wsp_size;
trim(hip_stream);
}
void PoolingWorkspaceCache::trim(hipStream_t hip_stream) {
if (memory_used < memory_budget && cache.size() < trim_size) return;
bool must_sync = true;
while (true) {
int new_size = cache.size() - (cache.size() >> 2);
std::vector<const void*> old_entries;
for (auto& x : cache)
if (x.second.timestamp + new_size < timestamp)
old_entries.push_back(x.first);
if (old_entries.empty()) break;
if (must_sync) hipStreamSynchronize(hip_stream);
must_sync = true;
for (auto x : old_entries) {
memory_used -= cache[x].workspace_size;
cache.erase(x);
}
if (memory_used < memory_budget || cache.size() < 10) break;
}
}
bool MIOpenSupport::DoPoolForward(
Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
const dnn::BatchDescriptor& input_dimensions,
@ -4020,7 +4112,6 @@ bool MIOpenSupport::DoPoolForward(
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<float>* output_data, ScratchAllocator* workspace_allocator) {
auto miopen = miopen_->GetHandle(parent_, stream);
// Alpha is the scaling factor for input.
float alpha = 1.0;
// Beta is the scaling factor for output.
@ -4030,10 +4121,48 @@ bool MIOpenSupport::DoPoolForward(
ScopedTensorDescriptor dest_desc{output_dimensions, miopenFloat};
ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
bool do_backward = false;
uint8* workspace = 0;
size_t workspace_size = 0;
std::unique_ptr<TemporaryDeviceMemory<uint8>> wsp_mem;
if (m_pooling_cache_enabled) {
do_backward = true;
auto status = wrap::miopenPoolingGetWorkSpaceSizeV2(
pooling_desc.handle(), dest_desc.handle(), &workspace_size);
if (status != miopenStatusSuccess) {
LOG(ERROR)
<< "failed to obtain workspace size for backward pooling on stream: "
<< ToString(status);
return false;
}
if (workspace_size != 0) {
PoolingWorkspaceDescriptor* pdesc = 0;
bool cache_hit =
m_pooling_cache_allowed &&
m_pooling_cache.find(input_data.opaque(), input_dimensions,
output_dimensions, pooling_dimensions,
miopenFloat, pdesc);
if (cache_hit) {
// reusing the same buffer
workspace = reinterpret_cast<uint8*>(
pdesc->workspace->mutable_device_memory()->opaque());
} else {
wsp_mem = stream->AllocateTemporaryArray<uint8>(workspace_size)
.ConsumeValueOrDie();
workspace = reinterpret_cast<uint8*>(
wsp_mem->mutable_device_memory()->opaque());
m_pooling_cache.insert(input_data.opaque(), input_dimensions,
output_dimensions, pooling_dimensions,
miopenFloat, wsp_mem, workspace_size,
AsGpuStreamValue(stream));
}
}
}
auto status = wrap::miopenPoolingForward(
miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque(),
false, nullptr, 0);
do_backward, workspace, workspace_size);
if (status != miopenStatusSuccess) {
LOG(ERROR) << "failed to enqueue forward pooling on stream: "
<< ToString(status);
@ -4072,6 +4201,118 @@ bool MIOpenSupport::DoPoolForward(
return true;
}
template <class T>
bool MIOpenSupport::DoPoolBackwardImpl(
Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
const dnn::BatchDescriptor& input_dimensions,
const DeviceMemory<T>& input_data,
const dnn::BatchDescriptor& output_dimensions,
const DeviceMemory<T>& output_data, const DeviceMemory<T>& input_diff_data,
DeviceMemory<T>* output_diff_data, ScratchAllocator* workspace_allocator) {
auto miopen = miopen_->GetHandle(parent_, stream);
if (m_pooling_cache_allowed) m_pooling_cache_enabled = true;
// Alpha is the scaling factor for input.
float alpha = 1.0;
// Beta is the scaling factor for output.
float beta = 0.0;
auto type =
std::is_same<T, float>::value
? miopenFloat
: (std::is_same<T, Eigen::half>::value ? miopenHalf
: (miopenDataType_t)-1);
ScopedTensorDescriptor src_desc{input_dimensions, type};
ScopedTensorDescriptor dest_desc{output_dimensions, type};
ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
uint8* workspace_ptr = 0;
DeviceMemory<uint8> workspace;
PoolingWorkspaceDescriptor* pdesc = 0;
size_t workspace_size_in_bytes = 0;
auto status = wrap::miopenPoolingGetWorkSpaceSizeV2(
pooling_desc.handle(), dest_desc.handle(), &workspace_size_in_bytes);
if (status != miopenStatusSuccess) {
LOG(ERROR)
<< "failed to obtain workspace size for backward pooling on stream: "
<< ToString(status);
return false;
}
// Allocate the workspace.
if (workspace_size_in_bytes > 0) {
bool cache_hit = m_pooling_cache_allowed &&
m_pooling_cache.find(input_data.opaque(), input_dimensions,
output_dimensions, pooling_dimensions,
type, pdesc);
if (cache_hit) {
assert(pdesc != 0);
workspace_ptr = reinterpret_cast<uint8*>(
pdesc->workspace->mutable_device_memory()->opaque());
VLOG(1) << "Pooling cache hit";
} else {
VLOG(1) << "Pooling cache miss";
assert(workspace_allocator);
auto allocated =
workspace_allocator->AllocateBytes(workspace_size_in_bytes);
if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
LOG(ERROR) << "Failed to allocate backward pooling workspace";
return false;
}
DeviceMemory<uint8> dest2; // duplicated dest from forward:
int64 dest2_size = 0;
// miopen requires the strides and dims to be ordered as BDYX.
std::vector<int64> dims64 =
output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
// miopen does not use strides and must have 4D tensor.
// std::vector<int> dims(pooling_dimensions.ndims() + 2);
dest2_size = sizeof(T);
for (auto& x : dims64) dest2_size *= x;
if (dest2_size > 0) {
assert(workspace_allocator);
auto allocated = workspace_allocator->AllocateBytes(dest2_size);
if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
LOG(ERROR) << "Failed to allocate backward pooling workspace";
return false;
}
} else {
LOG(ERROR) << "Failed to calculate tensor size to chain forward and "
"backward pooling";
}
status = wrap::miopenPoolingForward(
miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
input_data.opaque(), &beta, dest_desc.handle(), dest2.opaque(), true,
workspace.opaque(), workspace_size_in_bytes);
if (status != miopenStatusSuccess) {
LOG(ERROR)
<< "failed to enqueue forward pooling (before backward) on stream: "
<< ToString(status);
return false;
}
workspace_ptr = reinterpret_cast<uint8*>(workspace.opaque());
}
}
status = wrap::miopenPoolingBackward(
miopen.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
output_diff_data->opaque(), workspace_ptr);
if (status != miopenStatusSuccess) {
LOG(ERROR) << "failed to enqueue backward pooling on stream: "
<< ToString(status);
return false;
}
return true;
}
bool MIOpenSupport::DoPoolBackward(
Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
const dnn::BatchDescriptor& input_dimensions,
@ -4094,91 +4335,10 @@ bool MIOpenSupport::DoPoolBackward(
const DeviceMemory<float>& input_diff_data,
DeviceMemory<float>* output_diff_data,
ScratchAllocator* workspace_allocator) {
auto miopen = miopen_->GetHandle(parent_, stream);
// Alpha is the scaling factor for input.
float alpha = 1.0;
// Beta is the scaling factor for output.
float beta = 0.0;
ScopedTensorDescriptor src_desc{input_dimensions, miopenFloat};
ScopedTensorDescriptor dest_desc{output_dimensions, miopenFloat};
ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
DeviceMemory<uint8> workspace;
size_t workspace_size_in_bytes = 0;
auto status = wrap::miopenPoolingGetWorkSpaceSize(dest_desc.handle(),
&workspace_size_in_bytes);
if (status != miopenStatusSuccess) {
LOG(ERROR)
<< "failed to obtain workspace size for backward pooling on stream: "
<< ToString(status);
return false;
}
// Allocate the workspace.
if (workspace_size_in_bytes > 0) {
assert(workspace_allocator);
auto allocated =
workspace_allocator->AllocateBytes(workspace_size_in_bytes);
if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
LOG(ERROR) << "Failed to allocate backward pooling workspace";
return false;
}
}
DeviceMemory<uint8> dest2; // duplicated dest from forward:
int dest2_size = 0;
// miopen requires the strides and dims to be ordered as BDYX.
std::vector<int64> dims64 =
output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
// miopen does not use strides and must have 4D tensor.
std::vector<int> dims(4);
std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
&CheckedNarrowing<int64, int>);
dest2_size = dims[0] * dims[1] * dims[2] * dims[3] * sizeof(float);
if (dest2_size > 0) {
assert(workspace_allocator);
auto allocated = workspace_allocator->AllocateBytes(dest2_size);
if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
LOG(ERROR) << "Failed to allocate backward pooling workspace";
return false;
}
} else {
LOG(ERROR) << "Failed to calculate tensor size to chain forward and "
"backward pooling";
}
status = wrap::miopenPoolingForward(
miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
input_data.opaque(), &beta, dest_desc.handle(), dest2.opaque(), true,
workspace.opaque(), workspace_size_in_bytes);
if (status != miopenStatusSuccess) {
LOG(ERROR)
<< "failed to enqueue forward pooling (before backward) on stream: "
<< ToString(status);
return false;
}
status = wrap::miopenPoolingBackward(
miopen.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
dest2.opaque(), dest_desc.handle(), input_diff_data.opaque(),
src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
output_diff_data->opaque(), workspace.opaque());
if (status != miopenStatusSuccess) {
LOG(ERROR) << "failed to enqueue backward pooling on stream: "
<< ToString(status);
return false;
}
return true;
return DoPoolBackwardImpl(stream, pooling_dimensions, input_dimensions,
input_data, output_dimensions, output_data,
input_diff_data, output_diff_data,
workspace_allocator);
}
bool MIOpenSupport::DoPoolBackward(
@ -4190,91 +4350,10 @@ bool MIOpenSupport::DoPoolBackward(
const DeviceMemory<Eigen::half>& input_diff_data,
DeviceMemory<Eigen::half>* output_diff_data,
ScratchAllocator* workspace_allocator) {
auto miopen = miopen_->GetHandle(parent_, stream);
// Alpha is the scaling factor for input.
float alpha = 1.0;
// Beta is the scaling factor for output.
float beta = 0.0;
ScopedTensorDescriptor src_desc{input_dimensions, miopenHalf};
ScopedTensorDescriptor dest_desc{output_dimensions, miopenHalf};
ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
DeviceMemory<uint8> workspace;
size_t workspace_size_in_bytes = 0;
auto status = wrap::miopenPoolingGetWorkSpaceSize(dest_desc.handle(),
&workspace_size_in_bytes);
if (status != miopenStatusSuccess) {
LOG(ERROR)
<< "failed to obtain workspace size for backward pooling on stream: "
<< ToString(status);
return false;
}
// Allocate the workspace.
if (workspace_size_in_bytes > 0) {
assert(workspace_allocator);
auto allocated =
workspace_allocator->AllocateBytes(workspace_size_in_bytes);
if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
LOG(ERROR) << "Failed to allocate backward pooling workspace";
return false;
}
}
DeviceMemory<uint8> dest2; // duplicated dest from forward:
int dest2_size = 0;
// miopen requires the strides and dims to be ordered as BDYX.
std::vector<int64> dims64 =
output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
// miopen does not use strides and must have 4D tensor.
std::vector<int> dims(4);
std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
&CheckedNarrowing<int64, int>);
dest2_size = dims[0] * dims[1] * dims[2] * dims[3] * sizeof(float);
if (dest2_size > 0) {
assert(workspace_allocator);
auto allocated = workspace_allocator->AllocateBytes(dest2_size);
if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
LOG(ERROR) << "Failed to allocate backward pooling workspace";
return false;
}
} else {
LOG(ERROR) << "Failed to calculate tensor size to chain forward and "
"backward pooling";
}
status = wrap::miopenPoolingForward(
miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
input_data.opaque(), &beta, dest_desc.handle(), dest2.opaque(), true,
workspace.opaque(), workspace_size_in_bytes);
if (status != miopenStatusSuccess) {
LOG(ERROR)
<< "failed to enqueue forward pooling (before backward) on stream: "
<< ToString(status);
return false;
}
status = wrap::miopenPoolingBackward(
miopen.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
dest2.opaque(), dest_desc.handle(), input_diff_data.opaque(),
src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
output_diff_data->opaque(), workspace.opaque());
if (status != miopenStatusSuccess) {
LOG(ERROR) << "failed to enqueue backward pooling on stream: "
<< ToString(status);
return false;
}
return true;
return DoPoolBackwardImpl(stream, pooling_dimensions, input_dimensions,
input_data, output_dimensions, output_data,
input_diff_data, output_diff_data,
workspace_allocator);
}
bool MIOpenSupport::DoNormalizeWithDimensions(

View File

@ -20,6 +20,7 @@ limitations under the License.
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DNN_H_
#include "absl/synchronization/mutex.h"
#include "rocm/include/miopen/miopen.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/lib/status.h"
@ -38,6 +39,39 @@ class MIOpenCTCLossDescriptor;
// Opaque and unique identifier for the MIOpen plugin.
extern const PluginId kMIOpenPlugin;
struct PoolingWorkspaceDescriptor {
std::vector<int64> input_dims;
std::vector<int64> output_dims;
dnn::PoolingDescriptor op;
int dtype;
uint64_t timestamp;
std::unique_ptr<TemporaryDeviceMemory<uint8>> workspace;
size_t workspace_size;
bool IsSame(const dnn::BatchDescriptor& input_dimensions,
const dnn::BatchDescriptor& output_dimensions,
const dnn::PoolingDescriptor& pooling_dimensions, int _type);
};
struct PoolingWorkspaceCache {
std::map<const void*, PoolingWorkspaceDescriptor> cache;
const int trim_size = 1000;
const uint64_t memory_budget = 2e7;
uint64_t timestamp = 0;
uint64_t memory_used = 0;
bool find(const void* p, const dnn::BatchDescriptor& input_dimensions,
const dnn::BatchDescriptor& output_dimensions,
const dnn::PoolingDescriptor& pooling_dimensions, int _type,
PoolingWorkspaceDescriptor*& pdesc);
void insert(const void* p, const dnn::BatchDescriptor& input_dimensions,
const dnn::BatchDescriptor& output_dimensions,
const dnn::PoolingDescriptor& pooling_dimensions, int _type,
std::unique_ptr<TemporaryDeviceMemory<uint8>>& workspace,
size_t wsp_size, hipStream_t hip_stream);
private:
void trim(hipStream_t hip_stream);
};
// miopen-library based DNN support. For details on overridden interface
// functions, see dnn.h.
class MIOpenSupport : public dnn::DnnSupport {
@ -664,6 +698,10 @@ class MIOpenSupport : public dnn::DnnSupport {
// Provide access to the MIOpen handle.
std::unique_ptr<class MIOpenAccess> miopen_;
PoolingWorkspaceCache m_pooling_cache;
bool m_pooling_cache_allowed = false;
bool m_pooling_cache_enabled = false;
template <class T, class U>
bool DoBatchNormalizationForwardImpl(
Stream* stream, dnn::DataType input_data_type,
@ -847,6 +885,36 @@ class MIOpenSupport : public dnn::DnnSupport {
ScratchAllocator* scratch_allocator,
std::vector<dnn::ProfileResult>* out_algorithms);
port::Status DoCtcLossImpl(
Stream* stream, const MIOpenRnnStateTensorDescriptor& probs_desc,
const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
const MIOpenRnnStateTensorDescriptor& grads_desc,
DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc,
DeviceMemory<uint8> scratch_memory);
port::Status DoPrepareForCtcLoss(
Stream* stream, dnn::DataType element_type,
const dnn::RnnStateTensorDescriptor& probs_desc,
const dnn::RnnStateTensorDescriptor& grads_desc,
absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data,
ScratchAllocator* scratch_allocator,
DeviceMemory<uint8>* scratch_memory) override;
template <class T>
bool DoPoolBackwardImpl(Stream* stream,
const dnn::PoolingDescriptor& pooling_dimensions,
const dnn::BatchDescriptor& input_dimensions,
const DeviceMemory<T>& input_data,
const dnn::BatchDescriptor& output_dimensions,
const DeviceMemory<T>& output_data,
const DeviceMemory<T>& input_diff_data,
DeviceMemory<T>* output_diff_data,
ScratchAllocator* workspace_allocator = nullptr);
SE_DISALLOW_COPY_AND_ASSIGN(MIOpenSupport);
};