Merge pull request #40320 from ROCmSoftwarePlatform:google-upstream-eugene-3dpool
PiperOrigin-RevId: 317868942
Change-Id: I47a6de9d0c270a3d38abacd748fa57903e6b0673
commit 048ff6ab46
@@ -259,9 +259,6 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) {
  RunTest(x, x_init_value, y, y_shape);
}

// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, MaxPool3DGradHelper) {
  TensorShape x_shape({1, 3, 3, 3, 1});
  TensorShape y_shape({1, 1, 1, 1, 1});
@@ -274,7 +271,6 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) {
  SetRandomValuesForMaxPooling<float>(&x_init_value);
  RunTest(x, x_init_value, y, y_shape);
}
#endif

TEST_F(NNGradTest, AvgPoolGradHelper) {
  TensorShape x_shape({1, 2, 2, 1});
@@ -287,9 +283,6 @@ TEST_F(NNGradTest, AvgPoolGradHelper) {
  RunTest(x, x_shape, y, y_shape);
}

// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, AvgPool3DGradHelper) {
  TensorShape x_shape({1, 3, 3, 3, 1});
  TensorShape y_shape({1, 1, 1, 1, 1});
@@ -300,7 +293,6 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) {
  auto y = AvgPool3D(scope_, x, ksize, strides, "SAME");
  RunTest(x, x_shape, y, y_shape);
}
#endif

TEST_F(NNGradTest, LRN) {
  TensorShape x_shape({1, 1, 2, 1});
@@ -98,10 +98,25 @@ void DnnPooling3dOp<T>::Compute(OpKernelContext* context,
  auto* stream = context->op_device_context()->stream();
  OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

#if TENSORFLOW_USE_ROCM
  static int64 PoolingScratchSize = GetDnnWorkspaceLimit(
      // default value is in bytes despite the name of the environment variable
      "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
  );

  DnnScratchAllocator scratch_allocator(PoolingScratchSize, context);
  bool status =
      stream
          ->ThenPoolForward(pooling_desc, input_desc, input_data, output_desc,
                            &output_data, &scratch_allocator)
          .ok();
#else
  bool status = stream
                    ->ThenPoolForward(pooling_desc, input_desc, input_data,
                                      output_desc, &output_data)
                    .ok();
#endif

  OP_REQUIRES(context, status,
              errors::Internal("dnn PoolForward launch failed"));

@@ -225,12 +240,28 @@ void DnnPooling3dGradOp<T>::Compute(
  auto* stream = context->op_device_context()->stream();
  OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

#if TENSORFLOW_USE_ROCM
  static int64 PoolingScratchSize = GetDnnWorkspaceLimit(
      // default value is in bytes despite the name of the environment variable
      "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
  );

  DnnScratchAllocator scratch_allocator(PoolingScratchSize, context);
  bool status = stream
                    ->ThenPoolBackward(pooling_desc, orig_input_desc,
                                       orig_input_data, orig_output_desc,
                                       orig_output_data, output_backprop_data,
                                       &input_backprop_data, &scratch_allocator)
                    .ok();
#else
  bool status =
      stream
          ->ThenPoolBackward(pooling_desc, orig_input_desc, orig_input_data,
                             orig_output_desc, orig_output_data,
                             output_backprop_data, &input_backprop_data)
          .ok();
#endif

  OP_REQUIRES(context, status,
              errors::Internal("dnn PoolBackward launch failed"));

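Both hunks guard the new scratch-allocator path with #if TENSORFLOW_USE_ROCM and size it with GetDnnWorkspaceLimit. The in-line comment is the subtle part: the default is passed in bytes (1LL << 32, i.e. 4 GB), while a value supplied through TF_CUDNN_WORKSPACE_LIMIT_IN_MB is read in megabytes. A minimal sketch of that contract, assuming the conventional TensorFlow semantics of this helper (the body below is an illustration, not the actual implementation):

```cpp
#include <cstdint>
#include <cstdlib>
#include <string>

// Sketch of the GetDnnWorkspaceLimit contract assumed by the call sites
// above: the default is given in bytes (1LL << 32 == 4 GB), while a value
// set via TF_CUDNN_WORKSPACE_LIMIT_IN_MB is interpreted as megabytes.
int64_t GetDnnWorkspaceLimit(const std::string& envvar_in_mb,
                             int64_t default_value_in_bytes) {
  const char* workspace_limit = std::getenv(envvar_in_mb.c_str());
  if (workspace_limit == nullptr) return default_value_in_bytes;
  int64_t limit_in_mb = std::strtoll(workspace_limit, nullptr, 10);
  return limit_in_mb * (1ll << 20);  // megabytes -> bytes
}
```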
@@ -84,8 +84,9 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
    tf_y = tf_g1 * tf_g2 * tf_g3
    tf_grad = gradients.gradients(tf_y, [tf_var])[0]

    tf_dense_grad = math_ops.unsorted_segment_sum(
        tf_grad.values, tf_grad.indices, tf_grad.dense_shape[0])
    tf_dense_grad = math_ops.unsorted_segment_sum(tf_grad.values,
                                                  tf_grad.indices,
                                                  tf_grad.dense_shape[0])

    self.assertAllClose(grad, self.evaluate(tf_dense_grad))

@@ -127,8 +128,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
    self.assertAllEqual(grads_and_vars[0][0], 1.0)
    self.assertAllEqual(id(grads_and_vars[0][1]), id(x))

  @parameterized.named_parameters(
      [('Function', def_function.function),
  @parameterized.named_parameters([('Function', def_function.function),
                                   ('NoFunction', lambda f: f)])
  def testNoOpBehaviorConsistent(self, decorator):

@@ -195,8 +195,10 @@ class BackpropTest(test.TestCase, parameterized.TestCase):

    @custom_gradient.custom_gradient
    def identity(x):

      def grad(_):
        return []  # This return value is wrong!

      return x, grad

    x = variables.Variable(1.0)
@@ -234,8 +236,10 @@ class BackpropTest(test.TestCase, parameterized.TestCase):

    @custom_gradient.custom_gradient
    def f(x):

      def grad(_):
        raise RuntimeError('x')

      return x, grad

    # TODO(apassos) raise the right error here
@@ -337,7 +341,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
    x = constant_op.constant(2.0)
    with backprop.GradientTape() as t:
      t.watch(x)
      y = x*x
      y = x * x
    self.assertEqual(t.gradient([x, y], x).numpy(), 5.0)

  def testTapeNoOpGradientWithMultiTargetAllSource(self):
@@ -441,9 +445,11 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
    self.assertAllEqual(t.gradient(loss, v), 2.0)

  def testPythonMax(self):
    x = [resource_variable_ops.ResourceVariable(2.),
    x = [
        resource_variable_ops.ResourceVariable(2.),
        resource_variable_ops.ResourceVariable(3.),
        resource_variable_ops.ResourceVariable(5.)]
        resource_variable_ops.ResourceVariable(5.)
    ]
    with backprop.GradientTape() as t:
      f = max(x)
    grad = t.gradient(f, x)
@@ -538,8 +544,8 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
      with backprop.GradientTape() as tape2:
        tape1.watch(x1)
        tape2.watch([x1, x2])
        y = x1 ** 3
        z = x2 ** 2
        y = x1**3
        z = x2**2
      dy, dz = tape2.gradient([y, z], [x1, x2])
    d2y, d2z = tape1.gradient([dy, dz], [x1, x2])

@@ -602,6 +608,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):

  @test_util.assert_no_new_tensors
  def testArgmax(self):

    def argmax(x):
      i = math_ops.argmax(x)
      return array_ops.stop_gradient(i)
@@ -612,6 +619,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
  @test_util.run_gpu_only
  @test_util.assert_no_new_tensors
  def testGPU(self):

    def fn(x):
      with context.device('/gpu:0'):
        b = constant_op.constant(2.0)
@@ -634,8 +642,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
      with context.device('gpu:0'):
        return v.read_value()

    self.assertEqual(
        backprop.implicit_grad(f)()[0][0].cpu().numpy(), 1.0)
    self.assertEqual(backprop.implicit_grad(f)()[0][0].cpu().numpy(), 1.0)

  @test_util.assert_no_new_tensors
  def testCPU(self):
@@ -651,6 +658,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
  @test_util.run_gpu_only
  @test_util.assert_no_new_tensors
  def testTensorCopyGPU2CPU2GPU(self):

    def f(a, b):
      return a.cpu() + b.cpu()

@@ -675,8 +683,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):

  @test_util.assert_no_new_tensors
  def testUnconnectedNone(self):
    v = resource_variable_ops.ResourceVariable(
        1.0, name='testUnconnectedNone')
    v = resource_variable_ops.ResourceVariable(1.0, name='testUnconnectedNone')

    def f():
      v.read_value()
@@ -690,9 +697,9 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
    with g:
      x = constant_op.constant(3.0)
      g.watch(x)
      y = 2*x
      y = 2 * x
    with g:
      z = 2*y
      z = 2 * y
    grad = g.gradient(target=z, sources=[x])
    self.assertEqual(self.evaluate(grad), [4.0])

@@ -736,12 +743,20 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
    self.assertEqual(self.evaluate(g.gradient(y, x1)), [1.0])
    self.assertEqual(self.evaluate(g.gradient(y, (x1,))), (1.0,))
    self.assertEqual(self.evaluate(g.gradient(y, (x1, x2))), (1.0, 2.0))
    self.assertEqual(self.evaluate(g.gradient(y, [(x1, x2), (x2, x3)])),
                     [(1.0, 2.0), (2.0, 3.0)])
    self.assertEqual(self.evaluate(g.gradient(y, (x1, x2, [x1, x3]))),
    self.assertEqual(
        self.evaluate(g.gradient(y, [(x1, x2), (x2, x3)])), [(1.0, 2.0),
                                                             (2.0, 3.0)])
    self.assertEqual(
        self.evaluate(g.gradient(y, (x1, x2, [x1, x3]))),
        (1.0, 2.0, [1.0, 3.0]))
    self.assertEqual(self.evaluate(g.gradient(y, [x1, {'x2': x2, 'x3': x3}])),
                     [1.0, {'x2': 2.0, 'x3': 3.0}])
    self.assertEqual(
        self.evaluate(g.gradient(y, [x1, {
            'x2': x2,
            'x3': x3
        }])), [1.0, {
            'x2': 2.0,
            'x3': 3.0
        }])

  @test_util.assert_no_new_tensors
  @test_util.run_in_graph_and_eager_modes
@@ -846,13 +861,13 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
    with backprop.GradientTape(persistent=True) as g:
      x = constant_op.constant(3.0)
      g.watch(x)
      y = x ** 3  # y := x^3
      y = x**3  # y := x^3
      dy_dx = g.gradient(y, x)  # dy/dx := 3x^2
      d2y_dx2 = g.gradient(dy_dx, x)  # d2y/dx2 := 6x
      d3y_dx3 = g.gradient(d2y_dx2, x)  # d3y/dx3 := 6
    x = 3
    self.assertEqual(self.evaluate(y), x ** 3)
    self.assertEqual(self.evaluate(dy_dx), 3 * x ** 2)
    self.assertEqual(self.evaluate(y), x**3)
    self.assertEqual(self.evaluate(dy_dx), 3 * x**2)
    self.assertEqual(self.evaluate(d2y_dx2), 6 * x)
    self.assertEqual(self.evaluate(d3y_dx3), 6)
    del g
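The persistent-tape hunk above repeatedly differentiates the same scalar, and the expected values are just the power rule applied three times:

$$y = x^3, \qquad \frac{dy}{dx} = 3x^2, \qquad \frac{d^2y}{dx^2} = 6x, \qquad \frac{d^3y}{dx^3} = 6,$$

so at $x = 3$ the assertions expect $27$, $27$, $18$, and $6$ respectively.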
@@ -973,19 +988,17 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
    x = constant_op.constant(1.)
    with backprop.GradientTape() as g:
      g.watch(x)
      tape_lib.record_operation(
          'InvalidBackprop',
          [y],
          [x],
          lambda dy: [])
    with self.assertRaisesRegexp(
        errors_impl.InternalError, 'InvalidBackprop.*too few gradients'):
      tape_lib.record_operation('InvalidBackprop', [y], [x], lambda dy: [])
    with self.assertRaisesRegexp(errors_impl.InternalError,
                                 'InvalidBackprop.*too few gradients'):
      g.gradient(y, x)

  @test_util.assert_no_new_tensors
  def testEmptyParamsForValueAndGradFunction(self):

    def fn(a, b):
      return a * b

    val_and_grads_fn = backprop.val_and_grad_function(fn)

    x = 2.0
@@ -997,8 +1010,10 @@ class BackpropTest(test.TestCase, parameterized.TestCase):

  @test_util.assert_no_new_tensors
  def testNonEmptyParamsForValueAndGradFunction(self):

    def fn(a, b):
      return a * b

    val_and_grad_fn = backprop.val_and_grad_function(fn, params=[1])

    x = 2.0
@@ -1046,9 +1061,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
    def mul(x):
      return math_ops._mul_dispatch(x, x)  # pylint: disable=protected-access

    self.assertAllEqual(
        backprop.gradients_function(mul)(3.0)[0].numpy(),
        6.0)
    self.assertAllEqual(backprop.gradients_function(mul)(3.0)[0].numpy(), 6.0)

  def testMakeAttrShape(self):
    for s in ([], None, [1, 2, 3], [None, None], [1, None, 3]):
@@ -1057,8 +1070,8 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
      self.assertEqual(
          expected,
          actual,
          msg=('For shape %r, expected %r != %r actual' % (s, expected,
                                                           actual)))
          msg=('For shape %r, expected %r != %r actual' %
               (s, expected, actual)))

  def testMakeAttrShapeList(self):
    shape_list = [[], None, [1, 2, 3], [None, None], [1, None, 3]]
@@ -1081,8 +1094,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):

    part = functools.partial(f, constant_op.constant(2.0))
    self.assertAllEqual(
        backprop.gradients_function(part)(constant_op.constant(1.0))[0],
        2.0)
        backprop.gradients_function(part)(constant_op.constant(1.0))[0], 2.0)

  def testReturnSameThing(self):

@@ -1238,10 +1250,11 @@ class BackpropTest(test.TestCase, parameterized.TestCase):

    @custom_gradient.custom_gradient
    def my_mul(x, y):
      result = x*y
      result = x * y

      def grad(dr):
        return [dr*y, dr*x]
        return [dr * y, dr * x]

      return result, grad

    lr = 0.25
@@ -1257,7 +1270,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
      loss, grads_and_vars = loss_grads_fn(x)
      losses.append(loss.numpy())
      for (grad, var) in grads_and_vars:
        var.assign_sub(lr*grad)
        var.assign_sub(lr * grad)
    self.assertAllEqual(losses, [4.0, 3., 2., 1., 0.])

  @test_util.assert_no_new_tensors
@@ -1276,7 +1289,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
  def testDifferentiatingFunctionThatReturnsNone(self):

    def fn(x, y):
      result = x*y  # pylint: disable=unused-variable
      result = x * y  # pylint: disable=unused-variable

    x = constant_op.constant(1)
    y = constant_op.constant(2)
@@ -1295,6 +1308,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):

  def testZerosCacheDoesntLeakAcrossGraphs(self):
    with ops.Graph().as_default():

      def get_grad():
        with ops.Graph().as_default(), self.cached_session():
          t = constant_op.constant(1, dtype=dtypes.float32, shape=(10, 4))
@@ -1378,6 +1392,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testCustomGradientInEagerAndGraph(self):

    @custom_gradient.custom_gradient
    def f(x):
      y = x * x
@@ -1394,10 +1409,12 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
      self.assertAllEqual(self.evaluate(t.gradient(g, c)), 4.0)

  def testOverrideSecondOrderWithCustomGradient(self):

    @custom_gradient.custom_gradient
    def f(x):

      def first_order_grad(dz):

        @custom_gradient.custom_gradient
        def first_order_custom(unused_x):

@@ -1405,6 +1422,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
            return -2.1 * ddz

          return -1.1, h

        return dz * first_order_custom(x)

      return x + 10., first_order_grad
@@ -1414,25 +1432,31 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
      outer.watch(c)
      with backprop.GradientTape() as inner:
        inner.watch(c)
        d = f(c) ** 4.
        d = f(c)**4.
      dd = inner.gradient(d, c)
    self.assertAllClose(4. * f(c) ** 3. * -1.1, dd)
    self.assertAllClose(3. * 4. * f(c) ** 2. * -1.1 * -1.1
                        + 4. * f(c) ** 3. * -2.1,
    self.assertAllClose(4. * f(c)**3. * -1.1, dd)
    self.assertAllClose(3. * 4. * f(c)**2. * -1.1 * -1.1 + 4. * f(c)**3. * -2.1,
                        outer.gradient(dd, c))

  @test_util.run_in_graph_and_eager_modes
  def testCustomGradientForwardprop(self):

    @custom_gradient.custom_gradient
    def f(x):
      z = 2. * tensor_util.constant_value(x)

      def g(dz):

        @custom_gradient.custom_gradient
        def first_order(unused_x, unused_dz):

          def second_order_and_transpose(unused_ddz):
            return 2.2, 3.1

          return 2.1, second_order_and_transpose

        return first_order(x, dz)

      return z, g

    with backprop.GradientTape(persistent=True) as t:
@@ -1457,9 +1481,6 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
  @test_util.run_in_graph_and_eager_modes
  def testMaxPooling3DGradient(self):

    if test.is_built_with_rocm():
      self.skipTest('Pooling with 3D tensors is not supported in ROCm')

    def forward(a):
      r = max_pooling3d(a, pool_size=pool_size, strides=strides, padding='SAME')
      return r
@@ -1491,9 +1512,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
    with backprop.GradientTape() as t:
      values = constant_op.constant([1.0, 2.0], dtypes.float32)
      s = sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 2]],
          values=values,
          dense_shape=[3, 4])
          indices=[[0, 0], [1, 2]], values=values, dense_shape=[3, 4])
      t.watch(s)
      z = sparse_ops.sparse_reduce_sum_v2(s)
      result = t.gradient(z, values)
@@ -1529,6 +1548,7 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
    self.assertEqual((z,), tape.watched_variables())

  def testNameScope(self):

    def fn(x):
      with ops.name_scope('my_scope'):
        a = math_ops.cos(x)
@@ -1592,8 +1612,8 @@ class JacobianTest(test.TestCase):
      g.watch(x)
      g.watch(y)
      z = x * x * y
    jacobian = g.jacobian(z, [x, y],
                          experimental_use_pfor=experimental_use_pfor)
    jacobian = g.jacobian(
        z, [x, y], experimental_use_pfor=experimental_use_pfor)
    answer = [array_ops.diag(2 * x * y), array_ops.diag(x * x)]
    return jacobian, answer

@@ -1648,7 +1668,8 @@ class JacobianTest(test.TestCase):
      x = constant_op.constant([[1., 2], [3, 4]])
      g.watch(x)
      y = math_ops.matmul(x, x)
    self.assertAllClose(g.jacobian(y, x, parallel_iterations=2),
    self.assertAllClose(
        g.jacobian(y, x, parallel_iterations=2),
        g.jacobian(y, x, parallel_iterations=3))

  @test_util.run_in_graph_and_eager_modes
@@ -1690,7 +1711,8 @@ class BatchJacobianTest(test.TestCase, parameterized.TestCase):
      z = x * x * y
    batch_jacobian = g.batch_jacobian(
        z, x, experimental_use_pfor=experimental_use_pfor)
    answer = array_ops.stack([array_ops.diag(2 * x[0] * y[0]),
    answer = array_ops.stack(
        [array_ops.diag(2 * x[0] * y[0]),
         array_ops.diag(2 * x[1] * y[1])])
    return batch_jacobian, answer

@@ -1757,13 +1779,11 @@ class BatchJacobianTest(test.TestCase, parameterized.TestCase):
      g.watch(x)
      w = constant_op.constant([[1., 2, 3, 4], [5, 6, 7, 8]])
      y = math_ops.matmul(x, w)
    self.assertAllClose(g.batch_jacobian(y, x, parallel_iterations=2),
    self.assertAllClose(
        g.batch_jacobian(y, x, parallel_iterations=2),
        g.batch_jacobian(y, x, parallel_iterations=3))

  @parameterized.parameters(
      (True, True),
      (True, False),
      (False, True),
  @parameterized.parameters((True, True), (True, False), (False, True),
                            (False, False))
  def test_degenerate_shape(self, use_function, use_pfor):
@@ -2989,7 +2989,6 @@ cuda_py_test(
    name = "pooling_ops_3d_test",
    size = "medium",
    srcs = ["pooling_ops_3d_test.py"],
    tags = ["no_rocm"],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_for_generated_wrappers",
@@ -65,8 +65,8 @@ def pool_direct_single_axis(
  input_size = input.shape[axis]
  if padding == "SAME":
    output_size = int(math.ceil(input_size / stride))
    total_padding_amount = max(
        0, (output_size - 1) * stride + effective_window_size - input_size)
    total_padding_amount = max(0, (output_size - 1) * stride +
                               effective_window_size - input_size)
    before_padding = total_padding_amount // 2
  elif padding == "VALID":
    output_size = int(
@@ -219,8 +219,6 @@ class PoolingTest(test.TestCase):
        strides=strides)

  def testPool3D(self):
    if test.is_built_with_rocm():
      self.skipTest("Pooling with 3D tensors is not supported in ROCm")
    with self.session(use_gpu=test.is_gpu_available()):
      for padding in ["SAME", "VALID"]:
        for pooling_type in ["MAX", "AVG"]:
@@ -302,8 +300,10 @@ class PoolingTest(test.TestCase):
      x = constant_op.constant(x_val, name="x", dtype=dtypes.float32)
      output = nn_ops.pool(input=x, **kwargs)
      y_shape = output.get_shape().as_list()
      err = gradient_checker.compute_gradient_error(
          [x], [input_shape], output, y_shape, x_init_value=[x_val])
      err = gradient_checker.compute_gradient_error([x], [input_shape],
                                                    output,
                                                    y_shape,
                                                    x_init_value=[x_val])
      err_tolerance = 1e-2
      self.assertLess(err, err_tolerance)

@@ -363,8 +363,6 @@ class PoolingTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testGradient3D(self):
    if test.is_built_with_rocm():
      self.skipTest("Pooling with 3D tensors is not supported in ROCm")
    with self.session(use_gpu=test.is_gpu_available()):
      for padding in ["SAME", "VALID"]:
        for pooling_type in ["AVG", "MAX"]:
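The pool_direct_single_axis hunk only reflows the SAME-padding arithmetic, but the relation it encodes is worth stating once. With input length $n$, stride $s$, and effective window $w$:

$$\text{output} = \left\lceil \frac{n}{s} \right\rceil, \qquad p_{\text{total}} = \max\bigl(0,\ (\text{output} - 1)\,s + w - n\bigr), \qquad p_{\text{before}} = \left\lfloor p_{\text{total}} / 2 \right\rfloor.$$

For example, $n = 7$, $s = 2$, $w = 3$ gives $\text{output} = 4$, $p_{\text{total}} = \max(0, 6 + 3 - 7) = 2$, and $p_{\text{before}} = 1$.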
@@ -435,11 +435,7 @@ class NNTest(PForTestCase):
      with g:
        x1 = array_ops.gather(x, i)
        output = nn.avg_pool3d(
            x1,
            ksize,
            strides=strides,
            padding="VALID",
            data_format="NDHWC")
            x1, ksize, strides=strides, padding="VALID", data_format="NDHWC")
        loss = nn.l2_loss(output)
      return output, g.gradient(loss, x1)

@@ -488,8 +484,6 @@ class NNTest(PForTestCase):
    self._test_loop_fn(loop_fn, 3)

  def test_max_pool3d(self):
    if test.is_built_with_rocm():
      self.skipTest("Pooling with 3D tensors is not supported in ROCm")
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([3, 3, 2, 12, 12, 3])
      g.watch(x)
@@ -1012,12 +1006,12 @@ class TensorListTest(PForTestCase):
    # TensorListReserve operation.
    v2_enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.enable_control_flow_v2()

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      _, out_handle = control_flow_ops.while_loop(
          lambda j, _: j < 2,
          lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, i)),
          (0, handle))
          lambda j, _: j < 2, lambda j, h:
          (j + 1, list_ops.tensor_list_set_item(h, j, i)), (0, handle))
      return list_ops.tensor_list_stack(out_handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 2)
@@ -1140,9 +1134,8 @@ class WhileV1Test(PForTestCase):

    def loop_fn(_):
      return control_flow_ops.while_loop(
          lambda j, x: j < 4,
          lambda j, x: (j + 1, x + random_ops.random_uniform([])),
          [0, 0.])[0]
          lambda j, x: j < 4, lambda j, x:
          (j + 1, x + random_ops.random_uniform([])), [0, 0.])[0]

    self._test_loop_fn(loop_fn, 3)

@@ -1150,9 +1143,8 @@ class WhileV1Test(PForTestCase):
  def test_while_unstacked_condition(self):

    def loop_fn(i):
      return control_flow_ops.while_loop(
          lambda j, x: j < 4,
          lambda j, x: (j + 1, x + i), [0, 0])
      return control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
                                         (j + 1, x + i), [0, 0])

    self._test_loop_fn(loop_fn, 3)

@@ -1166,8 +1158,8 @@ class WhileV1Test(PForTestCase):
      lengths_i = array_ops.gather(lengths, i)

      _, total = control_flow_ops.while_loop(
          lambda j, _: j < lengths_i,
          lambda j, t: (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
          lambda j, _: j < lengths_i, lambda j, t:
          (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
      return total

    self._test_loop_fn(loop_fn, 3)
@@ -1371,6 +1363,7 @@ class WhileV2Test(PForTestCase):
    super(WhileV2Test, self).tearDown()

  def test_while_outside_loop(self):

    def _f():
      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])

@@ -1399,9 +1392,8 @@ class WhileV2Test(PForTestCase):

    def loop_fn(_):
      j, _ = control_flow_ops.while_loop(
          lambda j, x: j < 4,
          lambda j, x: (j + 1, x + random_ops.random_uniform([])),
          [0, 0.])
          lambda j, x: j < 4, lambda j, x:
          (j + 1, x + random_ops.random_uniform([])), [0, 0.])
      return j

    self._test_loop_fn(loop_fn, 3)
@@ -1410,9 +1402,8 @@ class WhileV2Test(PForTestCase):
    v = resource_variable_ops.ResourceVariable(5.)

    def loop_fn(_):
      _, output = control_flow_ops.while_loop(
          lambda j, x: j < 4,
          lambda j, x: (j + 1, x + v), [0, 0.])
      _, output = control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
                                              (j + 1, x + v), [0, 0.])
      return output

    self._test_loop_fn(loop_fn, 3)
@@ -1420,9 +1411,8 @@ class WhileV2Test(PForTestCase):
  def test_while_unstacked_condition(self):

    def loop_fn(i):
      return control_flow_ops.while_loop(
          lambda j, x: j < 4,
          lambda j, x: (j + 1, x + i), [0, 0])
      return control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
                                         (j + 1, x + i), [0, 0])

    self._test_loop_fn(loop_fn, 3)

@@ -1435,8 +1425,8 @@ class WhileV2Test(PForTestCase):
      lengths_i = array_ops.gather(lengths, i)

      return control_flow_ops.while_loop(
          lambda j, _: j < lengths_i,
          lambda j, t: (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
          lambda j, _: j < lengths_i, lambda j, t:
          (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])

    self._test_loop_fn(loop_fn, 3)

@@ -1446,25 +1436,29 @@ class WhileV2Test(PForTestCase):
    # It also tests inputs that are passed through.
    def loop_fn(i):
      return control_flow_ops.while_loop(
          lambda j, *_: j < i,
          lambda j, x, y, z, w: (j + 1, x + i, y + x, z, w),
          [0,
          lambda j, *_: j < i, lambda j, x, y, z, w:
          (j + 1, x + i, y + x, z, w), [
              0,
              constant_op.constant(0),
              constant_op.constant(1),
              i,
              constant_op.constant(2)])
              constant_op.constant(1), i,
              constant_op.constant(2)
          ])

    self._test_loop_fn(loop_fn, 3)

  def test_while_shape_invariants(self):

    def loop_fn(i):
      return control_flow_ops.while_loop(
          lambda j, *_: j < 4,
          lambda j, x, y: (j + 1, x + i, y + 1),
          [0, constant_op.constant([0, 1]), constant_op.constant([2, 3])],
          shape_invariants=[None,
          [0, constant_op.constant([0, 1]),
           constant_op.constant([2, 3])],
          shape_invariants=[
              None,
              tensor_shape.TensorShape([2]),
              tensor_shape.TensorShape([2])])
              tensor_shape.TensorShape([2])
          ])

    self._test_loop_fn(loop_fn, 3)

@@ -1486,8 +1480,8 @@ class WhileV2Test(PForTestCase):
      if use_pfor:
        return pfor_control_flow_ops.pfor(loop_fn, iters=3)
      else:
        return pfor_control_flow_ops.for_loop(loop_fn, iters=3,
                                              loop_fn_dtypes=out.dtype)
        return pfor_control_flow_ops.for_loop(
            loop_fn, iters=3, loop_fn_dtypes=out.dtype)

    x = constant_op.constant(np.random.uniform(size=(1, 3)))
    y = constant_op.constant(np.random.uniform(size=(3, 3)))
@@ -1512,9 +1506,9 @@ class NestedControlFlowTest(PForTestCase):
    f = lambda x, y: (x, y)

    def _f(x, y):
      return control_flow_ops.cond(y > split,
                                   lambda: f(x, y),
                                   lambda: (x + 1., y))
      return control_flow_ops.cond(y > split, lambda: f(x, y), lambda:
                                   (x + 1., y))

    return _f

  def _while(self, f=None):
@@ -1523,9 +1517,8 @@ class NestedControlFlowTest(PForTestCase):

    def _f(x, y):
      return control_flow_ops.while_loop(
          lambda j, _: j < y,
          lambda j, t: (j + 1, t + array_ops.gather(f(x, y)[0], j)),
          [0, x])[1], y
          lambda j, _: j < y, lambda j, t:
          (j + 1, t + array_ops.gather(f(x, y)[0], j)), [0, x])[1], y

    return _f

@@ -1566,10 +1559,8 @@ class StatelessIfTest(PForTestCase):
      x_i = array_ops.gather(x, i)
      # Note that the output has a combination of then and else branches being
      # loop variant / invariant.
      return cond_v2.cond_v2(
          x_i < y,
          lambda: (y - x_i, y, 1., 2.),
          lambda: (x_i - y, 0., y, 3.))
      return cond_v2.cond_v2(x_i < y, lambda: (y - x_i, y, 1., 2.), lambda:
                             (x_i - y, 0., y, 3.))

    self._test_loop_fn(loop_fn, iters=5)

@@ -1583,10 +1574,8 @@ class StatelessIfTest(PForTestCase):
      x_i = array_ops.gather(x, i)
      # Note that the output has a combination of then and else branches being
      # loop variant / invariant.
      return cond_v2.cond_v2(
          z < y,
          lambda: (y - x_i, y, 1., 2.),
          lambda: (x_i - y, 0., y, 3.))
      return cond_v2.cond_v2(z < y, lambda: (y - x_i, y, 1., 2.), lambda:
                             (x_i - y, 0., y, 3.))

    self._test_loop_fn(loop_fn, iters=5)

@@ -1619,10 +1608,7 @@ class IfTest(PForTestCase):
    @def_function.function
    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return cond_v2.cond_v2(
          x_i < y,
          lambda: z - x_i,
          lambda: z + x_i)
      return cond_v2.cond_v2(x_i < y, lambda: z - x_i, lambda: z + x_i)

    self._test_loop_fn(loop_fn, iters=5)

@@ -1736,9 +1722,8 @@ class Benchmarks(test.Benchmark):
    with ops.Graph().as_default():

      def loop_fn(i):
        _, s = control_flow_ops.while_loop(lambda t, x: t < i,
                                           lambda t, x: (t + 1, x + i),
                                           [0, 0])
        _, s = control_flow_ops.while_loop(lambda t, x: t < i, lambda t, x:
                                           (t + 1, x + i), [0, 0])
        return s

      iters = 50
@@ -2122,5 +2107,6 @@ class VariableTest(PForTestCase):

    self._test_loop_fn(loop_fn, 2)


if __name__ == "__main__":
  test.main()
@@ -263,7 +263,8 @@ namespace wrap {
  __macro(miopenFindConvolutionForwardAlgorithm) \
  __macro(miopenCreateTensorDescriptor)          \
  __macro(miopenDestroyTensorDescriptor)         \
  __macro(miopenSet2dPoolingDescriptor)          \
  __macro(miopenSetNdPoolingDescriptor)          \
  __macro(miopenSetPoolingIndexType)             \
  __macro(miopenSetLRNDescriptor)                \
  __macro(miopenLRNGetWorkSpaceSize)             \
  __macro(miopenCreateConvolutionDescriptor)     \
@@ -290,7 +291,7 @@ namespace wrap {
  __macro(miopenSetTensorDescriptor)             \
  __macro(miopenGetTensorDescriptorSize)         \
  __macro(miopenPoolingForward)                  \
  __macro(miopenPoolingGetWorkSpaceSize)         \
  __macro(miopenPoolingGetWorkSpaceSizeV2)       \
  __macro(miopenPoolingBackward)                 \
  __macro(miopenLRNForward)                      \
  __macro(miopenLRNBackward)                     \
@@ -605,6 +606,11 @@ MIOpenSupport::MIOpenSupport(GpuExecutor* parent) : parent_(parent) {
  // switch to Find Mode if env var TF_ROCM_USE_IMMEDIATE_MODE is set
  tensorflow::ReadBoolFromEnvVar("TF_ROCM_USE_IMMEDIATE_MODE", false,
                                 &use_immediate_mode_);

  bool enable_pooling_cache = false;
  tensorflow::ReadBoolFromEnvVar("TF_ROCM_BW_POOL_CACHE", false,
                                 &enable_pooling_cache);
  if (enable_pooling_cache) m_pooling_cache_allowed = true;
}

port::Status MIOpenSupport::Init() {
@@ -844,17 +850,19 @@ class ScopedPoolingDescriptor {
    std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
                   &CheckedNarrowing<int64, int>);

    if (nd != 2) {
      LOG(FATAL) << "miopen requires pooling dimensions be 2"
                 << ToString(status);
    }

    status = wrap::miopenSet2dPoolingDescriptor(
    status = wrap::miopenSetNdPoolingDescriptor(
        handle_,
        (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
             ? miopenPoolingMax
             : miopenPoolingAverage),
        shape[0], shape[1], padding[0], padding[1], strides[0], strides[1]);
        nd, shape.data(), padding.data(), strides.data());

    // Note: The index type has to be uint32 type for now because MIOpen
    // API assumes all input indexes to be the same type. Since a tensor
    // descriptor can only use int32 type, the index type here needs to be
    // aligned with the tensor index type of the (input) tensor descriptor.
    status = wrap::miopenSetPoolingIndexType(handle_, miopenIndexUint32);

    if (status != miopenStatusSuccess) {
      LOG(FATAL) << "could not set miopen pooling descriptor: "
                 << ToString(status);
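The descriptor hunk swaps the fixed-arity 2-D setter for the array-based Nd setter; the guard above still pins nd to 2, so this is groundwork for 3-D pooling rather than a behavior change here. A self-contained illustration of the argument repacking (values are made up, and the wrap:: call is quoted in a comment rather than invoked, since it needs a live MIOpen handle):

```cpp
#include <array>
#include <cstdio>

int main() {
  // The 2-D setter took scalars: shape[0], shape[1], padding[0], padding[1],
  // strides[0], strides[1]. The Nd setter takes nd plus three int arrays.
  constexpr int nd = 2;
  std::array<int, nd> shape = {3, 3};    // pooling window (illustrative)
  std::array<int, nd> padding = {1, 1};
  std::array<int, nd> strides = {2, 2};
  // The diff passes these as:
  //   wrap::miopenSetNdPoolingDescriptor(handle_, mode, nd,
  //       shape.data(), padding.data(), strides.data());
  std::printf("nd=%d window=%dx%d\n", nd, shape[0], shape[1]);
  return 0;
}
```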
@@ -4009,10 +4017,94 @@ bool MIOpenSupport::DoPoolForward(
    const DeviceMemory<double>& input_data,
    const dnn::BatchDescriptor& output_dimensions,
    DeviceMemory<double>* output_data, ScratchAllocator* workspace_allocator) {
  LOG(ERROR) << "miopen does not support pooling for dobule type yet";
  LOG(ERROR) << "miopen does not support pooling for double type yet";
  return false;
}

bool PoolingWorkspaceDescriptor::IsSame(
    const dnn::BatchDescriptor& input_dimensions,
    const dnn::BatchDescriptor& output_dimensions,
    const dnn::PoolingDescriptor& pooling_dimensions, int _type) {
  return dtype == _type &&
         input_dims ==
             input_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX) &&
         output_dims ==
             output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX) &&
         op.mode() == pooling_dimensions.mode() &&
         op.window() == pooling_dimensions.window() &&
         op.padding() == pooling_dimensions.padding() &&
         op.strides() == pooling_dimensions.strides();
}

bool PoolingWorkspaceCache::find(
    const void* p, const dnn::BatchDescriptor& input_dimensions,
    const dnn::BatchDescriptor& output_dimensions,
    const dnn::PoolingDescriptor& pooling_dimensions, int _type,
    PoolingWorkspaceDescriptor*& pdesc) {
  pdesc = 0;
  auto it = cache.find(p);
  if (it == cache.end()) {
    return false;
  }
  if (!it->second.IsSame(input_dimensions, output_dimensions,
                         pooling_dimensions, _type)) {
    return false;
  }
  pdesc = &it->second;
  return true;
}

void PoolingWorkspaceCache::insert(
    const void* p, const dnn::BatchDescriptor& input_dimensions,
    const dnn::BatchDescriptor& output_dimensions,
    const dnn::PoolingDescriptor& pooling_dimensions, int _type,
    std::unique_ptr<TemporaryDeviceMemory<uint8>>& workspace, size_t wsp_size,
    hipStream_t hip_stream) {
  PoolingWorkspaceDescriptor* desc = 0;
  auto it = cache.find(p);
  if (it != cache.end()) {
    // replacing an entry with the same pointer but different attributes
    // (if everything matches, the caller is expected to reuse the entry)
    desc = &it->second;
    hipStreamSynchronize(hip_stream);
    memory_used -= desc->workspace_size;
  } else {
    cache[p] = PoolingWorkspaceDescriptor();
    desc = &cache[p];
  }
  desc->input_dims = input_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
  desc->output_dims =
      output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
  desc->op = pooling_dimensions;
  desc->dtype = _type;
  desc->timestamp = timestamp;
  timestamp++;
  desc->workspace = std::move(workspace);
  desc->workspace_size = wsp_size;
  memory_used += wsp_size;
  trim(hip_stream);
}

void PoolingWorkspaceCache::trim(hipStream_t hip_stream) {
  if (memory_used < memory_budget && cache.size() < trim_size) return;
  bool must_sync = true;
  while (true) {
    int new_size = cache.size() - (cache.size() >> 2);
    std::vector<const void*> old_entries;
    for (auto& x : cache)
      if (x.second.timestamp + new_size < timestamp)
        old_entries.push_back(x.first);
    if (old_entries.empty()) break;
    if (must_sync) hipStreamSynchronize(hip_stream);
    must_sync = true;
    for (auto x : old_entries) {
      memory_used -= cache[x].workspace_size;
      cache.erase(x);
    }
    if (memory_used < memory_budget || cache.size() < 10) break;
  }
}

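PoolingWorkspaceCache::trim above implements a coarse LRU: once memory_used crosses the roughly 20 MB budget, or the entry count crosses trim_size, it repeatedly evicts entries whose timestamp falls outside the newest ~75% of insertions, synchronizing the HIP stream first so the GPU cannot still be reading a workspace that is about to be freed. A host-only model of the same bookkeeping, with device buffers and GPU synchronization stripped out (types simplified; this is an illustration, not the stream-executor code):

```cpp
#include <cstdint>
#include <map>
#include <vector>

// Host-only model of the eviction policy in PoolingWorkspaceCache::trim.
struct Entry {
  uint64_t timestamp = 0;
  size_t workspace_size = 0;
};

struct CacheModel {
  std::map<const void*, Entry> cache;
  uint64_t timestamp = 0;     // incremented on every insert
  uint64_t memory_used = 0;   // sum of cached workspace sizes
  const uint64_t memory_budget = 2e7;  // ~20 MB, as in the diff
  const size_t trim_size = 1000;

  void trim() {
    if (memory_used < memory_budget && cache.size() < trim_size) return;
    while (true) {
      // Keep only entries among the newest ~75% of insertions.
      size_t new_size = cache.size() - (cache.size() >> 2);
      std::vector<const void*> old_entries;
      for (auto& x : cache)
        if (x.second.timestamp + new_size < timestamp)
          old_entries.push_back(x.first);
      if (old_entries.empty()) break;
      for (auto x : old_entries) {
        memory_used -= cache[x].workspace_size;
        cache.erase(x);
      }
      if (memory_used < memory_budget || cache.size() < 10) break;
    }
  }
};
```

Note the shrinking loop: each pass drops roughly a quarter of the entries and rechecks the budget, bottoming out once fewer than ten entries remain.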
bool MIOpenSupport::DoPoolForward(
    Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
    const dnn::BatchDescriptor& input_dimensions,
@@ -4020,7 +4112,6 @@ bool MIOpenSupport::DoPoolForward(
    const dnn::BatchDescriptor& output_dimensions,
    DeviceMemory<float>* output_data, ScratchAllocator* workspace_allocator) {
  auto miopen = miopen_->GetHandle(parent_, stream);

  // Alpha is the scaling factor for input.
  float alpha = 1.0;
  // Beta is the scaling factor for output.
@@ -4030,10 +4121,48 @@ bool MIOpenSupport::DoPoolForward(
  ScopedTensorDescriptor dest_desc{output_dimensions, miopenFloat};
  ScopedPoolingDescriptor pooling_desc{pooling_dimensions};

  bool do_backward = false;
  uint8* workspace = 0;
  size_t workspace_size = 0;
  std::unique_ptr<TemporaryDeviceMemory<uint8>> wsp_mem;
  if (m_pooling_cache_enabled) {
    do_backward = true;
    auto status = wrap::miopenPoolingGetWorkSpaceSizeV2(
        pooling_desc.handle(), dest_desc.handle(), &workspace_size);
    if (status != miopenStatusSuccess) {
      LOG(ERROR)
          << "failed to obtain workspace size for backward pooling on stream: "
          << ToString(status);
      return false;
    }
    if (workspace_size != 0) {
      PoolingWorkspaceDescriptor* pdesc = 0;
      bool cache_hit =
          m_pooling_cache_allowed &&
          m_pooling_cache.find(input_data.opaque(), input_dimensions,
                               output_dimensions, pooling_dimensions,
                               miopenFloat, pdesc);
      if (cache_hit) {
        // reusing the same buffer
        workspace = reinterpret_cast<uint8*>(
            pdesc->workspace->mutable_device_memory()->opaque());
      } else {
        wsp_mem = stream->AllocateTemporaryArray<uint8>(workspace_size)
                      .ConsumeValueOrDie();
        workspace = reinterpret_cast<uint8*>(
            wsp_mem->mutable_device_memory()->opaque());
        m_pooling_cache.insert(input_data.opaque(), input_dimensions,
                               output_dimensions, pooling_dimensions,
                               miopenFloat, wsp_mem, workspace_size,
                               AsGpuStreamValue(stream));
      }
    }
  }

  auto status = wrap::miopenPoolingForward(
      miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
      input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque(),
      false, nullptr, 0);
      do_backward, workspace, workspace_size);
  if (status != miopenStatusSuccess) {
    LOG(ERROR) << "failed to enqueue forward pooling on stream: "
               << ToString(status);
@@ -4072,6 +4201,118 @@ bool MIOpenSupport::DoPoolForward(
  return true;
}

template <class T>
bool MIOpenSupport::DoPoolBackwardImpl(
    Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
    const dnn::BatchDescriptor& input_dimensions,
    const DeviceMemory<T>& input_data,
    const dnn::BatchDescriptor& output_dimensions,
    const DeviceMemory<T>& output_data, const DeviceMemory<T>& input_diff_data,
    DeviceMemory<T>* output_diff_data, ScratchAllocator* workspace_allocator) {
  auto miopen = miopen_->GetHandle(parent_, stream);
  if (m_pooling_cache_allowed) m_pooling_cache_enabled = true;
  // Alpha is the scaling factor for input.
  float alpha = 1.0;
  // Beta is the scaling factor for output.
  float beta = 0.0;

  auto type =
      std::is_same<T, float>::value
          ? miopenFloat
          : (std::is_same<T, Eigen::half>::value ? miopenHalf
                                                 : (miopenDataType_t)-1);

  ScopedTensorDescriptor src_desc{input_dimensions, type};
  ScopedTensorDescriptor dest_desc{output_dimensions, type};
  ScopedPoolingDescriptor pooling_desc{pooling_dimensions};

  uint8* workspace_ptr = 0;
  DeviceMemory<uint8> workspace;
  PoolingWorkspaceDescriptor* pdesc = 0;

  size_t workspace_size_in_bytes = 0;
  auto status = wrap::miopenPoolingGetWorkSpaceSizeV2(
      pooling_desc.handle(), dest_desc.handle(), &workspace_size_in_bytes);
  if (status != miopenStatusSuccess) {
    LOG(ERROR)
        << "failed to obtain workspace size for backward pooling on stream: "
        << ToString(status);
    return false;
  }

  // Allocate the workspace.
  if (workspace_size_in_bytes > 0) {
    bool cache_hit = m_pooling_cache_allowed &&
                     m_pooling_cache.find(input_data.opaque(), input_dimensions,
                                          output_dimensions, pooling_dimensions,
                                          type, pdesc);
    if (cache_hit) {
      assert(pdesc != 0);
      workspace_ptr = reinterpret_cast<uint8*>(
          pdesc->workspace->mutable_device_memory()->opaque());
      VLOG(1) << "Pooling cache hit";
    } else {
      VLOG(1) << "Pooling cache miss";
      assert(workspace_allocator);
      auto allocated =
          workspace_allocator->AllocateBytes(workspace_size_in_bytes);
      if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
        LOG(ERROR) << "Failed to allocate backward pooling workspace";
        return false;
      }
      DeviceMemory<uint8> dest2;  // duplicated dest from forward:
      int64 dest2_size = 0;

      // miopen requires the strides and dims to be ordered as BDYX.
      std::vector<int64> dims64 =
          output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
      // miopen does not use strides and must have 4D tensor.
      // std::vector<int> dims(pooling_dimensions.ndims() + 2);

      dest2_size = sizeof(T);
      for (auto& x : dims64) dest2_size *= x;

      if (dest2_size > 0) {
        assert(workspace_allocator);
        auto allocated = workspace_allocator->AllocateBytes(dest2_size);
        if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
          LOG(ERROR) << "Failed to allocate backward pooling workspace";
          return false;
        }
      } else {
        LOG(ERROR) << "Failed to calculate tensor size to chain forward and "
                      "backward pooling";
      }

      status = wrap::miopenPoolingForward(
          miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
          input_data.opaque(), &beta, dest_desc.handle(), dest2.opaque(), true,
          workspace.opaque(), workspace_size_in_bytes);

      if (status != miopenStatusSuccess) {
        LOG(ERROR)
            << "failed to enqueue forward pooling (before backward) on stream: "
            << ToString(status);
        return false;
      }
      workspace_ptr = reinterpret_cast<uint8*>(workspace.opaque());
    }
  }
  status = wrap::miopenPoolingBackward(
      miopen.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
      output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
      src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
      output_diff_data->opaque(), workspace_ptr);

  if (status != miopenStatusSuccess) {
    LOG(ERROR) << "failed to enqueue backward pooling on stream: "
               << ToString(status);
    return false;
  }

  return true;
}

bool MIOpenSupport::DoPoolBackward(
    Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
    const dnn::BatchDescriptor& input_dimensions,
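DoPoolBackwardImpl folds the float and half paths into one template; the dtype enum is picked at compile time through a nested std::is_same ternary, as the hunk shows. A self-contained sketch of that dispatch pattern (the enum values and the half placeholder are illustrative, not MIOpen's actual definitions):

```cpp
#include <type_traits>

// Sketch of the compile-time dtype dispatch used by DoPoolBackwardImpl above.
// MIOpen defines its own miopenDataType_t, and TensorFlow uses Eigen::half
// for the 16-bit case; both are stubbed out here.
enum miopenDataType_t { miopenHalf = 0, miopenFloat = 1, miopenInvalid = -1 };

struct half_placeholder {};  // stands in for Eigen::half in this sketch

template <class T>
constexpr miopenDataType_t MiopenTypeOf() {
  return std::is_same<T, float>::value
             ? miopenFloat
             : (std::is_same<T, half_placeholder>::value ? miopenHalf
                                                         : miopenInvalid);
}

static_assert(MiopenTypeOf<float>() == miopenFloat,
              "float maps to miopenFloat");
static_assert(MiopenTypeOf<half_placeholder>() == miopenHalf,
              "half maps to miopenHalf");
```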
@@ -4094,91 +4335,10 @@ bool MIOpenSupport::DoPoolBackward(
    const DeviceMemory<float>& input_diff_data,
    DeviceMemory<float>* output_diff_data,
    ScratchAllocator* workspace_allocator) {
  auto miopen = miopen_->GetHandle(parent_, stream);

  // Alpha is the scaling factor for input.
  float alpha = 1.0;
  // Beta is the scaling factor for output.
  float beta = 0.0;

  ScopedTensorDescriptor src_desc{input_dimensions, miopenFloat};
  ScopedTensorDescriptor dest_desc{output_dimensions, miopenFloat};
  ScopedPoolingDescriptor pooling_desc{pooling_dimensions};

  DeviceMemory<uint8> workspace;
  size_t workspace_size_in_bytes = 0;
  auto status = wrap::miopenPoolingGetWorkSpaceSize(dest_desc.handle(),
                                                    &workspace_size_in_bytes);

  if (status != miopenStatusSuccess) {
    LOG(ERROR)
        << "failed to obtain workspace size for backward pooling on stream: "
        << ToString(status);
    return false;
  }

  // Allocate the workspace.
  if (workspace_size_in_bytes > 0) {
    assert(workspace_allocator);
    auto allocated =
        workspace_allocator->AllocateBytes(workspace_size_in_bytes);
    if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
      LOG(ERROR) << "Failed to allocate backward pooling workspace";
      return false;
    }
  }

  DeviceMemory<uint8> dest2;  // duplicated dest from forward:
  int dest2_size = 0;

  // miopen requires the strides and dims to be ordered as BDYX.
  std::vector<int64> dims64 =
      output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);

  // miopen does not use strides and must have 4D tensor.
  std::vector<int> dims(4);

  std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
                 &CheckedNarrowing<int64, int>);

  dest2_size = dims[0] * dims[1] * dims[2] * dims[3] * sizeof(float);

  if (dest2_size > 0) {
    assert(workspace_allocator);
    auto allocated = workspace_allocator->AllocateBytes(dest2_size);
    if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
      LOG(ERROR) << "Failed to allocate backward pooling workspace";
      return false;
    }
  } else {
    LOG(ERROR) << "Failed to calculate tensor size to chain forward and "
                  "backward pooling";
  }

  status = wrap::miopenPoolingForward(
      miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
      input_data.opaque(), &beta, dest_desc.handle(), dest2.opaque(), true,
      workspace.opaque(), workspace_size_in_bytes);

  if (status != miopenStatusSuccess) {
    LOG(ERROR)
        << "failed to enqueue forward pooling (before backward) on stream: "
        << ToString(status);
    return false;
  }

  status = wrap::miopenPoolingBackward(
      miopen.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
      dest2.opaque(), dest_desc.handle(), input_diff_data.opaque(),
      src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
      output_diff_data->opaque(), workspace.opaque());

  if (status != miopenStatusSuccess) {
    LOG(ERROR) << "failed to enqueue backward pooling on stream: "
               << ToString(status);
    return false;
  }
  return true;
  return DoPoolBackwardImpl(stream, pooling_dimensions, input_dimensions,
                            input_data, output_dimensions, output_data,
                            input_diff_data, output_diff_data,
                            workspace_allocator);
}

bool MIOpenSupport::DoPoolBackward(
@@ -4190,91 +4350,10 @@ bool MIOpenSupport::DoPoolBackward(
    const DeviceMemory<Eigen::half>& input_diff_data,
    DeviceMemory<Eigen::half>* output_diff_data,
    ScratchAllocator* workspace_allocator) {
  auto miopen = miopen_->GetHandle(parent_, stream);

  // Alpha is the scaling factor for input.
  float alpha = 1.0;
  // Beta is the scaling factor for output.
  float beta = 0.0;

  ScopedTensorDescriptor src_desc{input_dimensions, miopenHalf};
  ScopedTensorDescriptor dest_desc{output_dimensions, miopenHalf};
  ScopedPoolingDescriptor pooling_desc{pooling_dimensions};

  DeviceMemory<uint8> workspace;
  size_t workspace_size_in_bytes = 0;
  auto status = wrap::miopenPoolingGetWorkSpaceSize(dest_desc.handle(),
                                                    &workspace_size_in_bytes);

  if (status != miopenStatusSuccess) {
    LOG(ERROR)
        << "failed to obtain workspace size for backward pooling on stream: "
        << ToString(status);
    return false;
  }

  // Allocate the workspace.
  if (workspace_size_in_bytes > 0) {
    assert(workspace_allocator);
    auto allocated =
        workspace_allocator->AllocateBytes(workspace_size_in_bytes);
    if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
      LOG(ERROR) << "Failed to allocate backward pooling workspace";
      return false;
    }
  }

  DeviceMemory<uint8> dest2;  // duplicated dest from forward:
  int dest2_size = 0;

  // miopen requires the strides and dims to be ordered as BDYX.
  std::vector<int64> dims64 =
      output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);

  // miopen does not use strides and must have 4D tensor.
  std::vector<int> dims(4);

  std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
                 &CheckedNarrowing<int64, int>);

  dest2_size = dims[0] * dims[1] * dims[2] * dims[3] * sizeof(float);

  if (dest2_size > 0) {
    assert(workspace_allocator);
    auto allocated = workspace_allocator->AllocateBytes(dest2_size);
    if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
      LOG(ERROR) << "Failed to allocate backward pooling workspace";
      return false;
    }
  } else {
    LOG(ERROR) << "Failed to calculate tensor size to chain forward and "
                  "backward pooling";
  }

  status = wrap::miopenPoolingForward(
      miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
      input_data.opaque(), &beta, dest_desc.handle(), dest2.opaque(), true,
      workspace.opaque(), workspace_size_in_bytes);

  if (status != miopenStatusSuccess) {
    LOG(ERROR)
        << "failed to enqueue forward pooling (before backward) on stream: "
        << ToString(status);
    return false;
  }

  status = wrap::miopenPoolingBackward(
      miopen.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
      dest2.opaque(), dest_desc.handle(), input_diff_data.opaque(),
      src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
      output_diff_data->opaque(), workspace.opaque());

  if (status != miopenStatusSuccess) {
    LOG(ERROR) << "failed to enqueue backward pooling on stream: "
               << ToString(status);
    return false;
  }
  return true;
  return DoPoolBackwardImpl(stream, pooling_dimensions, input_dimensions,
                            input_data, output_dimensions, output_data,
                            input_diff_data, output_diff_data,
                            workspace_allocator);
}

bool MIOpenSupport::DoNormalizeWithDimensions(

@@ -20,6 +20,7 @@ limitations under the License.
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DNN_H_

#include "absl/synchronization/mutex.h"
#include "rocm/include/miopen/miopen.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/lib/status.h"
@@ -38,6 +39,39 @@ class MIOpenCTCLossDescriptor;
// Opaque and unique identifier for the MIOpen plugin.
extern const PluginId kMIOpenPlugin;

struct PoolingWorkspaceDescriptor {
  std::vector<int64> input_dims;
  std::vector<int64> output_dims;
  dnn::PoolingDescriptor op;
  int dtype;
  uint64_t timestamp;
  std::unique_ptr<TemporaryDeviceMemory<uint8>> workspace;
  size_t workspace_size;
  bool IsSame(const dnn::BatchDescriptor& input_dimensions,
              const dnn::BatchDescriptor& output_dimensions,
              const dnn::PoolingDescriptor& pooling_dimensions, int _type);
};

struct PoolingWorkspaceCache {
  std::map<const void*, PoolingWorkspaceDescriptor> cache;
  const int trim_size = 1000;
  const uint64_t memory_budget = 2e7;
  uint64_t timestamp = 0;
  uint64_t memory_used = 0;
  bool find(const void* p, const dnn::BatchDescriptor& input_dimensions,
            const dnn::BatchDescriptor& output_dimensions,
            const dnn::PoolingDescriptor& pooling_dimensions, int _type,
            PoolingWorkspaceDescriptor*& pdesc);
  void insert(const void* p, const dnn::BatchDescriptor& input_dimensions,
              const dnn::BatchDescriptor& output_dimensions,
              const dnn::PoolingDescriptor& pooling_dimensions, int _type,
              std::unique_ptr<TemporaryDeviceMemory<uint8>>& workspace,
              size_t wsp_size, hipStream_t hip_stream);

 private:
  void trim(hipStream_t hip_stream);
};

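A note on how these declarations interact with the hunks in rocm_dnn.cc: the TF_ROCM_BW_POOL_CACHE environment variable only sets m_pooling_cache_allowed in the constructor, while m_pooling_cache_enabled is flipped lazily the first time a backward pooling op runs, so forward-only workloads never pay for workspace bookkeeping. Cache entries are keyed by the input tensor's device pointer and re-validated by IsSame against the batch dimensions, pooling attributes, and dtype before a workspace is reused.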
// miopen-library based DNN support. For details on overridden interface
// functions, see dnn.h.
class MIOpenSupport : public dnn::DnnSupport {
@@ -664,6 +698,10 @@ class MIOpenSupport : public dnn::DnnSupport {
  // Provide access to the MIOpen handle.
  std::unique_ptr<class MIOpenAccess> miopen_;

  PoolingWorkspaceCache m_pooling_cache;
  bool m_pooling_cache_allowed = false;
  bool m_pooling_cache_enabled = false;

  template <class T, class U>
  bool DoBatchNormalizationForwardImpl(
      Stream* stream, dnn::DataType input_data_type,
@@ -847,6 +885,36 @@ class MIOpenSupport : public dnn::DnnSupport {
      ScratchAllocator* scratch_allocator,
      std::vector<dnn::ProfileResult>* out_algorithms);

  port::Status DoCtcLossImpl(
      Stream* stream, const MIOpenRnnStateTensorDescriptor& probs_desc,
      const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
      absl::Span<const int> labels_lengths_data,
      absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
      const MIOpenRnnStateTensorDescriptor& grads_desc,
      DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc,
      DeviceMemory<uint8> scratch_memory);

  port::Status DoPrepareForCtcLoss(
      Stream* stream, dnn::DataType element_type,
      const dnn::RnnStateTensorDescriptor& probs_desc,
      const dnn::RnnStateTensorDescriptor& grads_desc,
      absl::Span<const int> labels_data,
      absl::Span<const int> labels_lengths_data,
      absl::Span<const int> input_lengths_data,
      ScratchAllocator* scratch_allocator,
      DeviceMemory<uint8>* scratch_memory) override;

  template <class T>
  bool DoPoolBackwardImpl(Stream* stream,
                          const dnn::PoolingDescriptor& pooling_dimensions,
                          const dnn::BatchDescriptor& input_dimensions,
                          const DeviceMemory<T>& input_data,
                          const dnn::BatchDescriptor& output_dimensions,
                          const DeviceMemory<T>& output_data,
                          const DeviceMemory<T>& input_diff_data,
                          DeviceMemory<T>* output_diff_data,
                          ScratchAllocator* workspace_allocator = nullptr);

  SE_DISALLOW_COPY_AND_ASSIGN(MIOpenSupport);
};