Enable depthwise convs in auto_mixed_precision
- These are well-supported as of cuDNN v8. - Also adds a Python test.
This commit is contained in:
parent
c60a66c4ee
commit
e481c700c8
tensorflow
@ -127,11 +127,6 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists {
|
||||
"GRUBlockCellGrad",
|
||||
"LSTMBlockCell",
|
||||
"LSTMBlockCellGrad",
|
||||
// TODO(benbarsdell): Enable these when fast and safe fp16 kernels are
|
||||
// available for depthwise convolutions.
|
||||
// "DepthwiseConv2dNative",
|
||||
// "DepthwiseConv2dNativeBackpropFilter",
|
||||
// "DepthwiseConv2dNativeBackpropInput",
|
||||
"MatMul",
|
||||
};
|
||||
if (cuda_version_ >= 9010) {
|
||||
@ -147,6 +142,11 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists {
|
||||
list.insert("Conv3DBackpropInput");
|
||||
list.insert("Conv3DBackpropInputV2");
|
||||
}
|
||||
if (cudnn_version_ >= 8000) {
|
||||
list.insert("DepthwiseConv2dNative");
|
||||
list.insert("DepthwiseConv2dNativeBackpropFilter");
|
||||
list.insert("DepthwiseConv2dNativeBackpropInput");
|
||||
}
|
||||
UpdateList("ALLOWLIST", &list);
|
||||
// For backwards compatibility, keeping the original env variable here.
|
||||
// TODO(reedwm): This should be removed if we don't have active users.
|
||||
|
@ -46,6 +46,7 @@ from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.losses import losses
|
||||
from tensorflow.python.platform import sysconfig
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import adam
|
||||
from tensorflow.python.training import gradient_descent
|
||||
@ -138,6 +139,11 @@ def _conv_pool(x):
|
||||
return h_pool2
|
||||
|
||||
|
||||
def _depthwise_conv2d(x, w):
  """Apply a 2d depthwise convolution with unit strides and SAME padding.

  Args:
    x: Input tensor (NHWC batch of images).
    w: Depthwise filter tensor.

  Returns:
    The depthwise convolution of `x` with `w`.
  """
  # Stride 1 in every dimension (batch, height, width, channel).
  return nn.depthwise_conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')
|
||||
|
||||
|
||||
def _simple_loop(x, functor):
|
||||
"""Simple loop whose body is provided by the functor."""
|
||||
init = (constant_op.constant(0), x)
|
||||
@ -566,6 +572,42 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
||||
tol = 5e-3 if mode == 'mkl' else 1e-3
|
||||
self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
|
||||
|
||||
# TODO(benbarsdell): This test has not been tried with MKL.
|
||||
@parameterized.parameters(['cuda'])
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_depthwise_conv2d(self, mode):
  """Test grad ops with depthwise convolution2d graph."""
  self._maybe_skip(mode)
  # Depthwise conv2d ops are only enabled in auto_mixed_precision as of
  # cuDNN v8, so skip on builds with an older cuDNN.
  build_cudnn = sysconfig.get_build_info().get('cudnn_version', '0.0')
  if tuple(int(part) for part in build_cudnn.split('.')) < (8,):
    self.skipTest('cuDNN version >= 8 required')

  random_seed.set_random_seed(0)
  x = _input([2, 8, 8, 1])
  f = _weight([3, 3, 1, 4])
  y = array_ops.identity(_depthwise_conv2d(x, f))
  optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
  g = optimizer.compute_gradients(y, [x, f])
  output = (y, g)

  # First run: inspect the optimized graph and check that the forward op and
  # both gradient ops were converted to fp16.
  output_val_ref, output_val, cost_graph = self._run(mode, output)
  node_map = _build_node_map(cost_graph.node)
  self._assert_output_f16(mode, node_map, 'depthwise')
  self._assert_output_f16(
      mode, node_map,
      'gradients/depthwise_grad/DepthwiseConv2dNativeBackpropInput')
  self._assert_output_f16(
      mode, node_map,
      'gradients/depthwise_grad/DepthwiseConv2dNativeBackpropFilter')

  # Second run: compare fp16 numerics against the fp32 reference.
  # NOTE(review): this re-executes the same graph as above — presumably to get
  # fresh values for the numeric comparison; confirm it is intentional.
  output_val_ref, output_val, cost_graph = self._run(mode, output)
  tol = 2e-3
  self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
|
||||
|
||||
@parameterized.parameters(['cuda', 'mkl'])
|
||||
@test_util.run_v1_only('b/138749235')
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
|
Loading…
Reference in New Issue
Block a user