Improve depthwise_conv_op_test.py.
Before, most of the forward pass tests didn't actually test anything. They compared the outputs of depthwise_conv2d and depthwise_conv2d_native to make sure they were the same. However, depthwise_conv2d simply forwards to depthwise_conv2d_native, adding dilations to the filter, and no test used dilations, so depthwise_conv2d and depthwise_conv2d_native were equivalent. I changed the test to add a Numpy implementation of depthwise convolutions and to compare depthwise_conv2d to the Numpy version. I additionally added dilation tests. I also made some refactors to make the code clearer. PiperOrigin-RevId: 300607077 Change-Id: Ib4fcdee02f633a62b884d81dd5772ecd42e2258f
This commit is contained in:
parent
46b7331fb1
commit
3511dbe928
@ -34,31 +34,136 @@ from tensorflow.python.platform import test
|
|||||||
from tensorflow.python.platform import tf_logging
|
from tensorflow.python.platform import tf_logging
|
||||||
|
|
||||||
|
|
||||||
|
def _DepthwiseConv2dNumpyBasic(x1, x2, strides):
|
||||||
|
"""Compute depthwise_conv2d using Numpy.
|
||||||
|
|
||||||
|
This allows use to test TensorFlow's depthwise_conv2d by comparing to the
|
||||||
|
Numpy version.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x1: The input Numpy array, in NHWC format.
|
||||||
|
x2: The filter Numpy array.
|
||||||
|
strides: A Python list of 4 elements representing the strides.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The depthwise conv2d output as a Numpy array.
|
||||||
|
"""
|
||||||
|
n, h, w, c = x1.shape
|
||||||
|
fh, fw, c2, o = x2.shape
|
||||||
|
assert c == c2
|
||||||
|
_, sh, sw, _ = strides
|
||||||
|
out_rows = (h - fh + sh) // sh
|
||||||
|
out_cols = (w - fw + sw) // sw
|
||||||
|
out = np.zeros([n, out_rows, out_cols, c * o])
|
||||||
|
for i in range(out_rows):
|
||||||
|
for j in range(out_cols):
|
||||||
|
for k in range(c):
|
||||||
|
start_height = i * sh
|
||||||
|
end_height = start_height + fh
|
||||||
|
start_width = j * sw
|
||||||
|
end_width = start_width + fw
|
||||||
|
# multiplied_slice.shape: (b, fh, fw, o)
|
||||||
|
multiplied_slice = (
|
||||||
|
x1[:, start_height:end_height, start_width:end_width, k, np.newaxis]
|
||||||
|
* x2[:, :, k, :])
|
||||||
|
# Set a slice of b * o elements of 'out'.
|
||||||
|
out[:, i, j, k * o:(k + 1) * o] = np.sum(multiplied_slice, axis=(1, 2))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _DepthwiseConv2dNumpy(x1, x2, strides, padding, data_format, dilations):
|
||||||
|
"""Compute depthwise_conv2d using Numpy.
|
||||||
|
|
||||||
|
This allows use to test TensorFlow's depthwise_conv2d by comparing to the
|
||||||
|
Numpy version.
|
||||||
|
|
||||||
|
Unlike `_DepthwiseConv2dNumpyBasic`, this supports more advanced features
|
||||||
|
like padding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x1: The input Numpy array.
|
||||||
|
x2: The filter Numpy array.
|
||||||
|
strides: A Python list of 4 elements representing the strides.
|
||||||
|
padding: The padding. "SAME", "VALID", or a list of explicit paddings.
|
||||||
|
data_format: "NHWC" or "NCHW".
|
||||||
|
dilations: A list of 2 elements, representing the dilations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The depthwise conv2d as a Numpy array.
|
||||||
|
"""
|
||||||
|
if data_format == "NCHW":
|
||||||
|
# Transpose arguments to NHWC format.
|
||||||
|
x1 = np.transpose(x1, (0, 3, 1, 2))
|
||||||
|
strides = [strides[0], strides[3], strides[1], strides[2]]
|
||||||
|
if dilations:
|
||||||
|
dilations = [dilations[0], dilations[3], dilations[1], dilations[2]]
|
||||||
|
|
||||||
|
if dilations:
|
||||||
|
# Dilate the filter so _DepthwiseConv2dNumpyBasic doesn't have to deal with
|
||||||
|
# dilations.
|
||||||
|
fh, fw, c, o = x2.shape
|
||||||
|
new_fh = (fh - 1) * dilations[0] + 1
|
||||||
|
new_fw = (fw - 1) * dilations[1] + 1
|
||||||
|
new_x2 = np.zeros((new_fh, new_fw, c, o))
|
||||||
|
for i in range(fh):
|
||||||
|
for j in range(fw):
|
||||||
|
new_x2[i * dilations[0], j * dilations[1], : :] = x2[i, j, :, :]
|
||||||
|
x2 = new_x2
|
||||||
|
|
||||||
|
# Pad input so _DepthwiseConv2dNumpyBasic doesn't have to deal with padding.
|
||||||
|
if padding == "SAME":
|
||||||
|
def PaddingsForDim(input_dim, filter_dim, stride):
|
||||||
|
"""Computes paddings for a single dimension."""
|
||||||
|
if input_dim % stride == 0:
|
||||||
|
total_padding = max(filter_dim - stride, 0)
|
||||||
|
else:
|
||||||
|
total_padding = max(filter_dim - (input_dim % stride), 0)
|
||||||
|
pad_before = total_padding // 2
|
||||||
|
pad_after = total_padding - pad_before
|
||||||
|
return pad_before, pad_after
|
||||||
|
padding = [(0, 0),
|
||||||
|
PaddingsForDim(x1.shape[1], x2.shape[0], strides[1]),
|
||||||
|
PaddingsForDim(x1.shape[2], x2.shape[1], strides[2]),
|
||||||
|
(0, 0)]
|
||||||
|
elif padding == "VALID":
|
||||||
|
padding = [(0, 0)] * 4
|
||||||
|
x1 = np.pad(x1, padding, "constant")
|
||||||
|
|
||||||
|
y = _DepthwiseConv2dNumpyBasic(x1, x2, strides)
|
||||||
|
|
||||||
|
if data_format == "NCHW":
|
||||||
|
# Transpose back to NCHW format.
|
||||||
|
y = np.transpose(y, (0, 2, 3, 1))
|
||||||
|
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
def ConfigsToTest():
|
def ConfigsToTest():
|
||||||
"""Iterator for different convolution shapes, strides and paddings.
|
"""Iterator for different convolution shapes, strides and paddings.
|
||||||
|
|
||||||
Yields:
|
Returns:
|
||||||
Tuple (input_size, filter_size, out_size, stride, padding), the depthwise
|
List of tuples (input_size, filter_size, out_size, stride, padding,
|
||||||
convolution parameters.
|
dilations), the depthwise convolution parameters.
|
||||||
"""
|
"""
|
||||||
input_sizes = [[4, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48], [4, 9, 27, 8],
|
def Config(input_size, filter_size, out_size, stride=1, padding="SAME",
|
||||||
[4, 31, 31, 7], [4, 35, 35, 2], [4, 147, 147, 2],
|
dilations=None):
|
||||||
[3, 299, 299, 3], [5, 183, 183, 1]]
|
return input_size, filter_size, out_size, stride, padding, dilations
|
||||||
filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [3, 1, 48, 4], [3, 3, 8, 1],
|
return [
|
||||||
[3, 3, 7, 1], [5, 5, 2, 1], [3, 3, 2, 8], [2, 2, 3,
|
Config([4, 5, 5, 48], [1, 1, 48, 2], [4, 5, 5, 96]),
|
||||||
8], [5, 5, 1, 2]]
|
Config([4, 8, 8, 84], [1, 3, 84, 1], [4, 8, 8, 84]),
|
||||||
out_sizes = [[4, 5, 5, 96], [4, 8, 8, 84], [4, 17, 17, 192], [4, 9, 27, 8],
|
Config([4, 17, 17, 48], [3, 1, 48, 4], [4, 17, 17, 192]),
|
||||||
[4, 31, 31, 7], [4, 35, 35, 2], [4, 49, 49, 16],
|
Config([4, 9, 27, 8], [3, 3, 8, 1], [4, 9, 27, 8]),
|
||||||
[3, 150, 150, 24], [5, 92, 92, 2]]
|
Config([4, 31, 31, 7], [3, 3, 7, 1], [4, 31, 31, 7]),
|
||||||
strides = [1, 1, 1, 1, 1, 1, 3, 2, 2]
|
Config([4, 35, 35, 2], [5, 5, 2, 1], [4, 35, 35, 2]),
|
||||||
# pylint: disable=invalid-name
|
Config([4, 147, 147, 2], [3, 3, 2, 8], [4, 49, 49, 16], 3,
|
||||||
VALID = "VALID"
|
padding="VALID"),
|
||||||
SAME = "SAME"
|
Config([3, 299, 299, 3], [3, 2, 3, 8], [3, 150, 150, 24], 2),
|
||||||
# pylint: enable=invalid-name
|
Config([5, 183, 183, 1], [5, 5, 1, 2], [5, 92, 92, 2], 2),
|
||||||
paddings = [SAME, SAME, SAME, SAME, SAME, SAME, VALID, SAME, SAME, SAME]
|
Config([5, 183, 183, 1], [5, 5, 1, 2], [5, 183, 183, 2],
|
||||||
for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
|
dilations=[2, 2]),
|
||||||
paddings):
|
Config([5, 41, 35, 2], [4, 7, 2, 2], [5, 32, 23, 4], padding="VALID",
|
||||||
yield i, f, o, s, p
|
dilations=[3, 2]),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def CheckGradConfigsToTest():
|
def CheckGradConfigsToTest():
|
||||||
@ -67,34 +172,26 @@ def CheckGradConfigsToTest():
|
|||||||
compute_gradient_error() is very expensive. So the configs should be
|
compute_gradient_error() is very expensive. So the configs should be
|
||||||
relatively small.
|
relatively small.
|
||||||
|
|
||||||
Yields:
|
Returns:
|
||||||
Tuple (input_size, filter_size, out_size, stride, padding), the depthwise
|
List of tuples (input_size, filter_size, out_size, stride, padding,
|
||||||
convolution parameters.
|
dilations), the depthwise convolution parameters.
|
||||||
"""
|
"""
|
||||||
input_sizes = [[2, 5, 8, 1], [4, 5, 5, 1], [2, 4, 4, 2], [1, 15, 15, 2],
|
def Config(input_size, filter_size, out_size, stride=1, padding="SAME",
|
||||||
[2, 15, 16, 1]]
|
dilations=None):
|
||||||
filter_sizes = [[4, 4, 1, 2], [2, 2, 1, 2], [3, 1, 2, 2], [1, 3, 2, 1],
|
return input_size, filter_size, out_size, stride, padding, dilations
|
||||||
[3, 3, 1, 2]]
|
return [
|
||||||
out_sizes = [[2, 5, 8, 2], [4, 2, 2, 2], [2, 4, 4, 4], [1, 15, 15, 2],
|
Config([2, 5, 8, 1], [4, 4, 1, 2], [2, 5, 8, 2]),
|
||||||
[2, 5, 5, 2]]
|
Config([4, 5, 5, 1], [2, 2, 1, 2], [4, 2, 2, 2], 2, padding="VALID"),
|
||||||
strides = [1, 2, 1, 1, 3]
|
Config([2, 4, 4, 2], [3, 1, 2, 2], [2, 4, 4, 4]),
|
||||||
# pylint: disable=invalid-name
|
Config([1, 15, 15, 2], [1, 3, 2, 1], [1, 15, 15, 2]),
|
||||||
VALID = "VALID"
|
Config([2, 15, 16, 1], [3, 3, 1, 2], [2, 5, 5, 2], 3, padding="VALID"),
|
||||||
SAME = "SAME"
|
Config([2, 5, 8, 1], [4, 3, 1, 2], [2, 5, 8, 2], dilations=[1, 2]),
|
||||||
# pylint: enable=invalid-name
|
]
|
||||||
paddings = [SAME, VALID, SAME, SAME, VALID]
|
|
||||||
for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
|
|
||||||
paddings):
|
|
||||||
yield i, f, o, s, p
|
|
||||||
|
|
||||||
|
|
||||||
class DepthwiseConv2DTest(test.TestCase):
|
class DepthwiseConv2DTest(test.TestCase):
|
||||||
|
|
||||||
# This is testing that depthwise_conv2d and depthwise_conv2d_native
|
# This tests depthwise_conv2d and depthwise_conv2d_native
|
||||||
# produce the same results. It also tests that NCHW and NHWC
|
|
||||||
# formats agree, by comparing the depthwise_conv2d_native with
|
|
||||||
# 'NCHW' format (with transposition) matches the 'NHWC' format using
|
|
||||||
# the higher level interface.
|
|
||||||
def _VerifyValues(self,
|
def _VerifyValues(self,
|
||||||
tensor_in_sizes,
|
tensor_in_sizes,
|
||||||
filter_in_sizes,
|
filter_in_sizes,
|
||||||
@ -103,7 +200,8 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
data_type,
|
data_type,
|
||||||
use_gpu,
|
use_gpu,
|
||||||
grouped_conv=False,
|
grouped_conv=False,
|
||||||
data_format="NHWC"):
|
data_format="NHWC",
|
||||||
|
dilations=None):
|
||||||
"""Verifies the output values of the convolution function.
|
"""Verifies the output values of the convolution function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -117,6 +215,7 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
use_gpu: Whether to use GPU.
|
use_gpu: Whether to use GPU.
|
||||||
grouped_conv: Whether to use cuDNN 7's grouped convolution.
|
grouped_conv: Whether to use cuDNN 7's grouped convolution.
|
||||||
data_format: The data_format of the input. "NHWC" or "NCHW".
|
data_format: The data_format of the input. "NHWC" or "NCHW".
|
||||||
|
dilations: A list of 2 elements, representing the dilations.
|
||||||
"""
|
"""
|
||||||
input_size = 1
|
input_size = 1
|
||||||
filter_size = 1
|
filter_size = 1
|
||||||
@ -126,7 +225,14 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
filter_size *= s
|
filter_size *= s
|
||||||
# Initializes the input and filter tensor with numbers incrementing from 1.
|
# Initializes the input and filter tensor with numbers incrementing from 1.
|
||||||
x1 = [f * 1.0 / input_size for f in range(1, input_size + 1)]
|
x1 = [f * 1.0 / input_size for f in range(1, input_size + 1)]
|
||||||
|
x1 = np.array(x1).reshape(tensor_in_sizes)
|
||||||
x2 = [f * 1.0 / filter_size for f in range(1, filter_size + 1)]
|
x2 = [f * 1.0 / filter_size for f in range(1, filter_size + 1)]
|
||||||
|
x2 = np.array(x2).reshape(filter_in_sizes)
|
||||||
|
# Compute reference result
|
||||||
|
strides = [1, stride, stride, 1]
|
||||||
|
np_result = _DepthwiseConv2dNumpy(x1, x2, strides, padding, "NHWC",
|
||||||
|
dilations)
|
||||||
|
|
||||||
ops.reset_default_graph()
|
ops.reset_default_graph()
|
||||||
graph = ops.get_default_graph()
|
graph = ops.get_default_graph()
|
||||||
with self.session(graph=graph, use_gpu=use_gpu) as sess:
|
with self.session(graph=graph, use_gpu=use_gpu) as sess:
|
||||||
@ -137,60 +243,62 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
}[data_type]
|
}[data_type]
|
||||||
|
|
||||||
t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=data_type)
|
t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=data_type)
|
||||||
t1.set_shape(tensor_in_sizes)
|
|
||||||
t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=data_type)
|
t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=data_type)
|
||||||
|
|
||||||
native_t1 = t1
|
|
||||||
strides = [1, stride, stride, 1]
|
|
||||||
if data_format == "NCHW":
|
if data_format == "NCHW":
|
||||||
# Transpose from NHWC input to NCHW
|
# Transpose from NHWC input to NCHW
|
||||||
# Ex. [4, 5, 5, 48] to [4, 48, 5, 5]
|
# Ex. [4, 5, 5, 48] to [4, 48, 5, 5]
|
||||||
native_t1 = array_ops.transpose(t1, [0, 3, 1, 2])
|
t1 = array_ops.transpose(t1, [0, 3, 1, 2])
|
||||||
strides = [1, 1, stride, stride]
|
strides = [1, 1, stride, stride]
|
||||||
|
|
||||||
with sess.graph._kernel_label_map({
|
# depthwise_conv2d_native does not support dilations except on TPUs.
|
||||||
"DepthwiseConv2dNative": "cudnn_grouped_convolution"
|
if dilations is None:
|
||||||
} if grouped_conv else {}):
|
with sess.graph._kernel_label_map({
|
||||||
conv_native = nn_ops.depthwise_conv2d_native(
|
"DepthwiseConv2dNative": "cudnn_grouped_convolution"
|
||||||
native_t1,
|
} if grouped_conv else {}):
|
||||||
t2,
|
conv_native = nn_ops.depthwise_conv2d_native(
|
||||||
strides=strides,
|
t1,
|
||||||
data_format=data_format,
|
t2,
|
||||||
padding=padding)
|
strides=strides,
|
||||||
|
data_format=data_format,
|
||||||
|
padding=padding)
|
||||||
|
|
||||||
if data_format == "NCHW":
|
if data_format == "NCHW":
|
||||||
# Transpose back from NCHW to NHWC
|
# Transpose back from NCHW to NHWC
|
||||||
conv_native = array_ops.transpose(conv_native, [0, 2, 3, 1])
|
conv_native = array_ops.transpose(conv_native, [0, 2, 3, 1])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
native_result = self.evaluate(conv_native)
|
# The Numpy array from calling depthwise_conv2d_native
|
||||||
except errors.InvalidArgumentError as e:
|
native_result = self.evaluate(conv_native)
|
||||||
# Grouped convolution kernel is only registered for cuDNN 7. Silently
|
except errors.InvalidArgumentError as e:
|
||||||
# return when we are running on an earlier version or without GPU.
|
# Grouped convolution kernel is only registered for cuDNN 7. Silently
|
||||||
if e.message.startswith(
|
# return when we are running on an earlier version or without GPU.
|
||||||
"No OpKernel was registered to support Op 'DepthwiseConv2dNative'"):
|
if e.message.startswith(
|
||||||
tf_logging.warn("Skipping grouped convolution test")
|
"No OpKernel was registered to support Op "
|
||||||
return
|
"'DepthwiseConv2dNative'"):
|
||||||
raise e
|
tf_logging.warn("Skipping grouped convolution test")
|
||||||
|
return
|
||||||
|
raise e
|
||||||
|
|
||||||
conv_interface = nn_impl.depthwise_conv2d(
|
conv_interface = nn_impl.depthwise_conv2d(
|
||||||
t1, t2, strides=[1, stride, stride, 1], padding=padding)
|
t1, t2, strides=strides, padding=padding,
|
||||||
|
data_format=data_format, dilations=dilations)
|
||||||
|
if data_format == "NCHW":
|
||||||
|
# Transpose back from NCHW to NHWC
|
||||||
|
conv_interface = array_ops.transpose(conv_interface, [0, 2, 3, 1])
|
||||||
|
|
||||||
|
# The Numpy array from calling depthwise_conv2d
|
||||||
interface_result = self.evaluate(conv_interface)
|
interface_result = self.evaluate(conv_interface)
|
||||||
|
|
||||||
tf_logging.info(
|
if dilations is None:
|
||||||
"data_type: %r, use_gpu: %r, grouped_conv: %r, max diff = %f",
|
self.assertAllClose(native_result, np_result, atol=tolerance, rtol=0.)
|
||||||
data_type, use_gpu, grouped_conv,
|
self.assertAllClose(interface_result, np_result, atol=tolerance, rtol=0.)
|
||||||
np.amax(np.absolute(native_result - interface_result)))
|
|
||||||
self.assertArrayNear(
|
|
||||||
np.ravel(native_result), np.ravel(interface_result), tolerance)
|
|
||||||
self.assertShapeEqual(native_result, conv_native)
|
|
||||||
self.assertShapeEqual(native_result, conv_interface)
|
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
@test_util.run_cuda_only
|
@test_util.run_cuda_only
|
||||||
def testDepthwiseConv2DCudnn(self):
|
def testDepthwiseConv2DCudnn(self):
|
||||||
for index, (input_size, filter_size, _, stride,
|
for index, (input_size, filter_size, _, stride,
|
||||||
padding) in enumerate(ConfigsToTest()):
|
padding, dilations) in enumerate(ConfigsToTest()):
|
||||||
# The CuDNN depthwise conv is turned on only when input/output is NCHW and
|
# The CuDNN depthwise conv is turned on only when input/output is NCHW and
|
||||||
# float16(half). See cudnn release note 7.6.3.
|
# float16(half). See cudnn release note 7.6.3.
|
||||||
tf_logging.info(
|
tf_logging.info(
|
||||||
@ -204,12 +312,13 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
padding,
|
padding,
|
||||||
data_type,
|
data_type,
|
||||||
use_gpu=True,
|
use_gpu=True,
|
||||||
data_format="NCHW")
|
data_format="NCHW",
|
||||||
|
dilations=dilations)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testDepthwiseConv2D(self):
|
def testDepthwiseConv2D(self):
|
||||||
for index, (input_size, filter_size, _, stride,
|
for index, (input_size, filter_size, _, stride,
|
||||||
padding) in enumerate(ConfigsToTest()):
|
padding, dilations) in enumerate(ConfigsToTest()):
|
||||||
tf_logging.info(
|
tf_logging.info(
|
||||||
"Testing DepthwiseConv2D, %dth config: %r * %r, stride: %d, padding: "
|
"Testing DepthwiseConv2D, %dth config: %r * %r, stride: %d, padding: "
|
||||||
"%s", index, input_size, filter_size, stride, padding)
|
"%s", index, input_size, filter_size, stride, padding)
|
||||||
@ -219,7 +328,8 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
for data_type in ([dtypes.float32] + optional_float64):
|
for data_type in ([dtypes.float32] + optional_float64):
|
||||||
tf_logging.info("Testing without grouped_conv")
|
tf_logging.info("Testing without grouped_conv")
|
||||||
self._VerifyValues(
|
self._VerifyValues(
|
||||||
input_size, filter_size, stride, padding, data_type, use_gpu=True)
|
input_size, filter_size, stride, padding, data_type, use_gpu=True,
|
||||||
|
dilations=dilations)
|
||||||
tf_logging.info("Testing with grouped_conv")
|
tf_logging.info("Testing with grouped_conv")
|
||||||
self._VerifyValues(
|
self._VerifyValues(
|
||||||
input_size,
|
input_size,
|
||||||
@ -228,7 +338,8 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
padding,
|
padding,
|
||||||
data_type,
|
data_type,
|
||||||
use_gpu=True,
|
use_gpu=True,
|
||||||
grouped_conv=True)
|
grouped_conv=True,
|
||||||
|
dilations=dilations)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testDepthwiseConv2DWithUnknownShape(self):
|
def testDepthwiseConv2DWithUnknownShape(self):
|
||||||
@ -250,7 +361,7 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
return
|
return
|
||||||
|
|
||||||
for index, (input_size, filter_size, _, stride,
|
for index, (input_size, filter_size, _, stride,
|
||||||
padding) in enumerate(ConfigsToTest()):
|
padding, dilations) in enumerate(ConfigsToTest()):
|
||||||
tf_logging.info(
|
tf_logging.info(
|
||||||
"Testing DepthwiseConv2DFormat, %dth config: %r * %r, stride: %d, "
|
"Testing DepthwiseConv2DFormat, %dth config: %r * %r, stride: %d, "
|
||||||
"padding: %s", index, input_size, filter_size, stride, padding)
|
"padding: %s", index, input_size, filter_size, stride, padding)
|
||||||
@ -265,7 +376,8 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
padding,
|
padding,
|
||||||
data_type,
|
data_type,
|
||||||
use_gpu=True,
|
use_gpu=True,
|
||||||
data_format="NCHW")
|
data_format="NCHW",
|
||||||
|
dilations=dilations)
|
||||||
|
|
||||||
# This is testing against hand calculated results.
|
# This is testing against hand calculated results.
|
||||||
|
|
||||||
@ -385,7 +497,8 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
test_input,
|
test_input,
|
||||||
use_gpu,
|
use_gpu,
|
||||||
grouped_conv=False,
|
grouped_conv=False,
|
||||||
data_format="NHWC"):
|
data_format="NHWC",
|
||||||
|
dilations=None):
|
||||||
input_size = 1
|
input_size = 1
|
||||||
for x in input_shape:
|
for x in input_shape:
|
||||||
input_size *= x
|
input_size *= x
|
||||||
@ -393,7 +506,9 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
for x in filter_shape:
|
for x in filter_shape:
|
||||||
filter_size *= x
|
filter_size *= x
|
||||||
input_data = [x * 1.0 / input_size for x in range(0, input_size)]
|
input_data = [x * 1.0 / input_size for x in range(0, input_size)]
|
||||||
|
input_np = np.array(input_data).reshape(input_shape)
|
||||||
filter_data = [x * 1.0 / filter_size for x in range(0, filter_size)]
|
filter_data = [x * 1.0 / filter_size for x in range(0, filter_size)]
|
||||||
|
filter_np = np.array(filter_data).reshape(filter_shape)
|
||||||
ops.reset_default_graph()
|
ops.reset_default_graph()
|
||||||
graph = ops.get_default_graph()
|
graph = ops.get_default_graph()
|
||||||
with self.session(graph=graph, use_gpu=use_gpu) as sess:
|
with self.session(graph=graph, use_gpu=use_gpu) as sess:
|
||||||
@ -404,9 +519,9 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
}[data_type]
|
}[data_type]
|
||||||
|
|
||||||
input_tensor = constant_op.constant(
|
input_tensor = constant_op.constant(
|
||||||
input_data, shape=input_shape, dtype=data_type, name="input")
|
input_np, shape=input_shape, dtype=data_type, name="input")
|
||||||
filter_tensor = constant_op.constant(
|
filter_tensor = constant_op.constant(
|
||||||
filter_data, shape=filter_shape, dtype=data_type, name="filter")
|
filter_np, shape=filter_shape, dtype=data_type, name="filter")
|
||||||
|
|
||||||
native_input = input_tensor
|
native_input = input_tensor
|
||||||
strides = [1, stride, stride, 1]
|
strides = [1, stride, stride, 1]
|
||||||
@ -427,12 +542,13 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
"DepthwiseConv2dNativeBackpropInput": "cudnn_grouped_convolution",
|
"DepthwiseConv2dNativeBackpropInput": "cudnn_grouped_convolution",
|
||||||
"DepthwiseConv2dNativeBackpropFilter": "cudnn_grouped_convolution",
|
"DepthwiseConv2dNativeBackpropFilter": "cudnn_grouped_convolution",
|
||||||
} if grouped_conv else {}):
|
} if grouped_conv else {}):
|
||||||
depthwise_conv2d = nn_ops.depthwise_conv2d_native(
|
depthwise_conv2d = nn_impl.depthwise_conv2d(
|
||||||
native_input,
|
native_input,
|
||||||
filter_tensor,
|
filter_tensor,
|
||||||
strides,
|
strides,
|
||||||
padding,
|
padding,
|
||||||
data_format=data_format,
|
data_format=data_format,
|
||||||
|
dilations=dilations,
|
||||||
name="depthwise_conv2d")
|
name="depthwise_conv2d")
|
||||||
|
|
||||||
self.assertEqual(output_shape, depthwise_conv2d.get_shape())
|
self.assertEqual(output_shape, depthwise_conv2d.get_shape())
|
||||||
@ -462,7 +578,7 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
@test_util.run_cuda_only
|
@test_util.run_cuda_only
|
||||||
def testDepthwiseConv2DInputGradCudnn(self):
|
def testDepthwiseConv2DInputGradCudnn(self):
|
||||||
for index, (input_size, filter_size, output_size, stride,
|
for index, (input_size, filter_size, output_size, stride,
|
||||||
padding) in enumerate(CheckGradConfigsToTest()):
|
padding, dilations) in enumerate(CheckGradConfigsToTest()):
|
||||||
# The CuDNN depthwise conv (input gradient) is turned on only when
|
# The CuDNN depthwise conv (input gradient) is turned on only when
|
||||||
# stride = 1, input/output is NCHW and float16(half). See cudnn release
|
# stride = 1, input/output is NCHW and float16(half). See cudnn release
|
||||||
# note 7.6.3.
|
# note 7.6.3.
|
||||||
@ -482,12 +598,13 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
data_type,
|
data_type,
|
||||||
test_input=True,
|
test_input=True,
|
||||||
use_gpu=True,
|
use_gpu=True,
|
||||||
data_format="NCHW")
|
data_format="NCHW",
|
||||||
|
dilations=dilations)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testDepthwiseConv2DInputGrad(self):
|
def testDepthwiseConv2DInputGrad(self):
|
||||||
for index, (input_size, filter_size, output_size, stride,
|
for index, (input_size, filter_size, output_size, stride,
|
||||||
padding) in enumerate(CheckGradConfigsToTest()):
|
padding, dilations) in enumerate(CheckGradConfigsToTest()):
|
||||||
tf_logging.info(
|
tf_logging.info(
|
||||||
"Testing DepthwiseConv2DInputGrad, %dth config: %r * %r, stride: %d, "
|
"Testing DepthwiseConv2DInputGrad, %dth config: %r * %r, stride: %d, "
|
||||||
"padding: %s", index, input_size, filter_size, stride, padding)
|
"padding: %s", index, input_size, filter_size, stride, padding)
|
||||||
@ -503,7 +620,8 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
padding,
|
padding,
|
||||||
data_type,
|
data_type,
|
||||||
test_input=True,
|
test_input=True,
|
||||||
use_gpu=True)
|
use_gpu=True,
|
||||||
|
dilations=dilations)
|
||||||
self._ConstructAndTestGradient(
|
self._ConstructAndTestGradient(
|
||||||
input_size,
|
input_size,
|
||||||
filter_size,
|
filter_size,
|
||||||
@ -513,7 +631,8 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
data_type,
|
data_type,
|
||||||
test_input=True,
|
test_input=True,
|
||||||
use_gpu=True,
|
use_gpu=True,
|
||||||
grouped_conv=True)
|
grouped_conv=True,
|
||||||
|
dilations=dilations)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testDepthwiseConv2DInputGradFormat(self):
|
def testDepthwiseConv2DInputGradFormat(self):
|
||||||
@ -521,7 +640,7 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
return
|
return
|
||||||
|
|
||||||
for index, (input_size, filter_size, output_size, stride,
|
for index, (input_size, filter_size, output_size, stride,
|
||||||
padding) in enumerate(CheckGradConfigsToTest()):
|
padding, dilations) in enumerate(CheckGradConfigsToTest()):
|
||||||
tf_logging.info(
|
tf_logging.info(
|
||||||
"Testing DepthwiseConv2DInputGradFormat, %dth config: %r * %r, "
|
"Testing DepthwiseConv2DInputGradFormat, %dth config: %r * %r, "
|
||||||
"stride: %d, padding: %s", index, input_size, filter_size, stride,
|
"stride: %d, padding: %s", index, input_size, filter_size, stride,
|
||||||
@ -539,13 +658,14 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
data_type,
|
data_type,
|
||||||
test_input=True,
|
test_input=True,
|
||||||
use_gpu=True,
|
use_gpu=True,
|
||||||
data_format="NCHW")
|
data_format="NCHW",
|
||||||
|
dilations=dilations)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
@test_util.run_cuda_only
|
@test_util.run_cuda_only
|
||||||
def testDepthwiseConv2DFilterGradCudnn(self):
|
def testDepthwiseConv2DFilterGradCudnn(self):
|
||||||
for index, (input_size, filter_size, output_size, stride,
|
for index, (input_size, filter_size, output_size, stride,
|
||||||
padding) in enumerate(CheckGradConfigsToTest()):
|
padding, dilations) in enumerate(CheckGradConfigsToTest()):
|
||||||
# The CuDNN depthwise conv (filter gradient) is turned on only when
|
# The CuDNN depthwise conv (filter gradient) is turned on only when
|
||||||
# input/output is float16(half). See cudnn release note 7.6.3.
|
# input/output is float16(half). See cudnn release note 7.6.3.
|
||||||
tf_logging.info(
|
tf_logging.info(
|
||||||
@ -562,7 +682,8 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
data_type,
|
data_type,
|
||||||
test_input=False,
|
test_input=False,
|
||||||
use_gpu=True,
|
use_gpu=True,
|
||||||
data_format="NCHW")
|
data_format="NCHW",
|
||||||
|
dilations=dilations)
|
||||||
self._ConstructAndTestGradient(
|
self._ConstructAndTestGradient(
|
||||||
input_size,
|
input_size,
|
||||||
filter_size,
|
filter_size,
|
||||||
@ -572,12 +693,13 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
data_type,
|
data_type,
|
||||||
test_input=False,
|
test_input=False,
|
||||||
use_gpu=True,
|
use_gpu=True,
|
||||||
data_format="NHWC")
|
data_format="NHWC",
|
||||||
|
dilations=dilations)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testDepthwiseConv2DFilterGrad(self):
|
def testDepthwiseConv2DFilterGrad(self):
|
||||||
for index, (input_size, filter_size, output_size, stride,
|
for index, (input_size, filter_size, output_size, stride,
|
||||||
padding) in enumerate(CheckGradConfigsToTest()):
|
padding, dilations) in enumerate(CheckGradConfigsToTest()):
|
||||||
tf_logging.info(
|
tf_logging.info(
|
||||||
"Testing DepthwiseConv2DFilterGrad, %dth config: %r * %r, stride: "
|
"Testing DepthwiseConv2DFilterGrad, %dth config: %r * %r, stride: "
|
||||||
"%d, padding: %s", index, input_size, filter_size, stride, padding)
|
"%d, padding: %s", index, input_size, filter_size, stride, padding)
|
||||||
@ -593,7 +715,8 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
padding,
|
padding,
|
||||||
data_type,
|
data_type,
|
||||||
test_input=False,
|
test_input=False,
|
||||||
use_gpu=True)
|
use_gpu=True,
|
||||||
|
dilations=dilations)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testDepthwiseConv2DFilterGradFormat(self):
|
def testDepthwiseConv2DFilterGradFormat(self):
|
||||||
@ -601,7 +724,7 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
return
|
return
|
||||||
|
|
||||||
for index, (input_size, filter_size, output_size, stride,
|
for index, (input_size, filter_size, output_size, stride,
|
||||||
padding) in enumerate(CheckGradConfigsToTest()):
|
padding, dilations) in enumerate(CheckGradConfigsToTest()):
|
||||||
tf_logging.info(
|
tf_logging.info(
|
||||||
"Testing DepthwiseConv2DFilterGradFormat, %dth config: %r * %r, "
|
"Testing DepthwiseConv2DFilterGradFormat, %dth config: %r * %r, "
|
||||||
"stride: %d, padding: %s", index, input_size, filter_size, stride,
|
"stride: %d, padding: %s", index, input_size, filter_size, stride,
|
||||||
@ -619,32 +742,13 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
data_type,
|
data_type,
|
||||||
test_input=False,
|
test_input=False,
|
||||||
use_gpu=True,
|
use_gpu=True,
|
||||||
data_format="NCHW")
|
data_format="NCHW",
|
||||||
|
dilations=dilations)
|
||||||
|
|
||||||
def _CompareBackpropInputFloat(self, input_sizes, filter_sizes, output_sizes,
|
def _CompareBackpropInput(self, input_sizes, filter_sizes, output_sizes,
|
||||||
stride, padding):
|
stride, padding, dtype):
|
||||||
x1 = np.random.rand(*filter_sizes).astype(np.float32)
|
x1 = np.random.rand(*filter_sizes).astype(dtype)
|
||||||
x2 = np.random.rand(*output_sizes).astype(np.float32)
|
x2 = np.random.rand(*output_sizes).astype(dtype)
|
||||||
|
|
||||||
def _GetVal(use_gpu):
|
|
||||||
with self.cached_session(use_gpu=use_gpu):
|
|
||||||
t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)])
|
|
||||||
t1 = constant_op.constant(x1, shape=filter_sizes)
|
|
||||||
t2 = constant_op.constant(x2, shape=output_sizes)
|
|
||||||
backprop = nn_ops.depthwise_conv2d_native_backprop_input(
|
|
||||||
t0, t1, t2, strides=[1, stride, stride, 1], padding=padding)
|
|
||||||
ret = self.evaluate(backprop)
|
|
||||||
self.assertShapeEqual(ret, backprop)
|
|
||||||
return ret
|
|
||||||
|
|
||||||
gpu_value = _GetVal(use_gpu=True)
|
|
||||||
cpu_value = _GetVal(use_gpu=False)
|
|
||||||
self.assertAllClose(cpu_value, gpu_value, rtol=1e-4, atol=1e-4)
|
|
||||||
|
|
||||||
def _CompareBackpropInputDouble(self, input_sizes, filter_sizes, output_sizes,
|
|
||||||
stride, padding):
|
|
||||||
x1 = np.random.rand(*filter_sizes).astype(np.float64)
|
|
||||||
x2 = np.random.rand(*output_sizes).astype(np.float64)
|
|
||||||
|
|
||||||
def _GetVal(use_gpu):
|
def _GetVal(use_gpu):
|
||||||
with self.cached_session(use_gpu=use_gpu):
|
with self.cached_session(use_gpu=use_gpu):
|
||||||
@ -663,44 +767,26 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
|
|
||||||
def testDepthwiseConv2DInputGradCompare(self):
|
def testDepthwiseConv2DInputGradCompare(self):
|
||||||
for index, (input_size, filter_size, output_size, stride,
|
for index, (input_size, filter_size, output_size, stride,
|
||||||
padding) in enumerate(ConfigsToTest()):
|
padding, dilations) in enumerate(ConfigsToTest()):
|
||||||
|
if dilations:
|
||||||
|
continue
|
||||||
tf_logging.info(
|
tf_logging.info(
|
||||||
"Testing DepthwiseConv2DInputGradCompare, %dth config: %r * %r, "
|
"Testing DepthwiseConv2DInputGradCompare, %dth config: %r * %r, "
|
||||||
"stride: %d, padding: %s", index, input_size, filter_size, stride,
|
"stride: %d, padding: %s", index, input_size, filter_size, stride,
|
||||||
padding)
|
padding)
|
||||||
self._CompareBackpropInputFloat(input_size, filter_size, output_size,
|
self._CompareBackpropInput(input_size, filter_size, output_size, stride,
|
||||||
stride, padding)
|
padding, "float32")
|
||||||
# double datatype is currently not supported for convolution ops
|
# double datatype is currently not supported for convolution ops
|
||||||
# on the ROCm platform
|
# on the ROCm platform
|
||||||
if test.is_built_with_rocm():
|
if test.is_built_with_rocm():
|
||||||
continue
|
continue
|
||||||
self._CompareBackpropInputDouble(input_size, filter_size, output_size,
|
self._CompareBackpropInput(input_size, filter_size, output_size, stride,
|
||||||
stride, padding)
|
padding, "float64")
|
||||||
|
|
||||||
def _CompareBackpropFilterFloat(self, input_sizes, filter_sizes, output_sizes,
|
def _CompareBackpropFilter(self, input_sizes, filter_sizes, output_sizes,
|
||||||
stride, padding):
|
stride, padding, dtype):
|
||||||
x0 = np.random.rand(*input_sizes).astype(np.float32)
|
x0 = np.random.rand(*input_sizes).astype(dtype)
|
||||||
x2 = np.random.rand(*output_sizes).astype(np.float32)
|
x2 = np.random.rand(*output_sizes).astype(dtype)
|
||||||
|
|
||||||
def _GetVal(use_gpu):
|
|
||||||
with self.cached_session(use_gpu=use_gpu):
|
|
||||||
t0 = constant_op.constant(x0, shape=input_sizes)
|
|
||||||
t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
|
|
||||||
t2 = constant_op.constant(x2, shape=output_sizes)
|
|
||||||
backprop = nn_ops.depthwise_conv2d_native_backprop_filter(
|
|
||||||
t0, t1, t2, strides=[1, stride, stride, 1], padding=padding)
|
|
||||||
ret = self.evaluate(backprop)
|
|
||||||
self.assertShapeEqual(ret, backprop)
|
|
||||||
return ret
|
|
||||||
|
|
||||||
gpu_value = _GetVal(use_gpu=True)
|
|
||||||
cpu_value = _GetVal(use_gpu=False)
|
|
||||||
self.assertAllClose(cpu_value, gpu_value, rtol=1e-4, atol=1e-4)
|
|
||||||
|
|
||||||
def _CompareBackpropFilterDouble(self, input_sizes, filter_sizes,
|
|
||||||
output_sizes, stride, padding):
|
|
||||||
x0 = np.random.rand(*input_sizes).astype(np.float64)
|
|
||||||
x2 = np.random.rand(*output_sizes).astype(np.float64)
|
|
||||||
|
|
||||||
def _GetVal(use_gpu):
|
def _GetVal(use_gpu):
|
||||||
with self.cached_session(use_gpu=use_gpu):
|
with self.cached_session(use_gpu=use_gpu):
|
||||||
@ -719,19 +805,21 @@ class DepthwiseConv2DTest(test.TestCase):
|
|||||||
|
|
||||||
def testDepthwiseConv2DFilterGradCompare(self):
|
def testDepthwiseConv2DFilterGradCompare(self):
|
||||||
for index, (input_size, filter_size, output_size, stride,
|
for index, (input_size, filter_size, output_size, stride,
|
||||||
padding) in enumerate(ConfigsToTest()):
|
padding, dilations) in enumerate(ConfigsToTest()):
|
||||||
|
if dilations:
|
||||||
|
continue
|
||||||
tf_logging.info(
|
tf_logging.info(
|
||||||
"Testing DepthwiseConv2DFilterGradCompare, %dth config: %r * %r, "
|
"Testing DepthwiseConv2DFilterGradCompare, %dth config: %r * %r, "
|
||||||
"stride: %d, padding: %s", index, input_size, filter_size, stride,
|
"stride: %d, padding: %s", index, input_size, filter_size, stride,
|
||||||
padding)
|
padding)
|
||||||
self._CompareBackpropFilterFloat(input_size, filter_size, output_size,
|
self._CompareBackpropFilter(input_size, filter_size, output_size, stride,
|
||||||
stride, padding)
|
padding, "float32")
|
||||||
# double datatype is currently not supported for convolution ops
|
# double datatype is currently not supported for convolution ops
|
||||||
# on the ROCm platform
|
# on the ROCm platform
|
||||||
if test.is_built_with_rocm():
|
if test.is_built_with_rocm():
|
||||||
continue
|
continue
|
||||||
self._CompareBackpropFilterDouble(input_size, filter_size, output_size,
|
self._CompareBackpropFilter(input_size, filter_size, output_size, stride,
|
||||||
stride, padding)
|
padding, "float64")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user