Improve depthwise_conv_op_test.py.

Before, most of the forward pass tests didn't actually test anything: they compared the outputs of depthwise_conv2d and depthwise_conv2d_native to make sure they were the same. But depthwise_conv2d simply forwards to depthwise_conv2d_native, only first applying dilations to the filter, and since no test used dilations, the two ops were trivially equivalent.

I changed the test to add a Numpy implementation of depthwise convolution, and to compare depthwise_conv2d against the Numpy version. I additionally added dilation tests.

I also made some refactors to make the code clearer.
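A minimal sketch of the equivalence described above, using the public TF1-style API (illustrative only, not part of this change):

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

x = tf.constant(np.random.rand(1, 5, 5, 2), dtype=tf.float32)
f = tf.constant(np.random.rand(3, 3, 2, 1), dtype=tf.float32)
# Without dilations, depthwise_conv2d dispatches straight to the native op,
# so comparing the two outputs checks nothing independent.
a = tf.nn.depthwise_conv2d(x, f, strides=[1, 1, 1, 1], padding="SAME")
b = tf.nn.depthwise_conv2d_native(x, f, strides=[1, 1, 1, 1], padding="SAME")
with tf.Session() as sess:
  va, vb = sess.run([a, b])
np.testing.assert_allclose(va, vb)  # trivially equal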

PiperOrigin-RevId: 300607077
Change-Id: Ib4fcdee02f633a62b884d81dd5772ecd42e2258f
Author: Reed Wanderman-Milne, 2020-03-12 12:58:34 -07:00 (committed by TensorFlower Gardener)

@@ -34,31 +34,136 @@ from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
def _DepthwiseConv2dNumpyBasic(x1, x2, strides):
"""Compute depthwise_conv2d using Numpy.
This allows us to test TensorFlow's depthwise_conv2d by comparing it to the
Numpy version.
Args:
x1: The input Numpy array, in NHWC format.
x2: The filter Numpy array.
strides: A Python list of 4 elements representing the strides.
Returns:
The depthwise conv2d output as a Numpy array.
"""
n, h, w, c = x1.shape
fh, fw, c2, o = x2.shape
assert c == c2
_, sh, sw, _ = strides
out_rows = (h - fh + sh) // sh
out_cols = (w - fw + sw) // sw
out = np.zeros([n, out_rows, out_cols, c * o])
for i in range(out_rows):
for j in range(out_cols):
for k in range(c):
start_height = i * sh
end_height = start_height + fh
start_width = j * sw
end_width = start_width + fw
# multiplied_slice.shape: (n, fh, fw, o)
multiplied_slice = (
x1[:, start_height:end_height, start_width:end_width, k, np.newaxis]
* x2[:, :, k, :])
# Set a slice of n * o elements of 'out'.
out[:, i, j, k * o:(k + 1) * o] = np.sum(multiplied_slice, axis=(1, 2))
return out
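As a quick sanity check of the helper above (a sketch assuming _DepthwiseConv2dNumpyBasic is in scope; not commit code): on an all-ones input, every output element is just the filter window size.

import numpy as np
x = np.ones((1, 3, 3, 1))       # NHWC input
f = np.ones((2, 2, 1, 1))       # (fh, fw, c, o) filter
y = _DepthwiseConv2dNumpyBasic(x, f, [1, 1, 1, 1])
assert y.shape == (1, 2, 2, 1)  # out_rows = (3 - 2 + 1) // 1 = 2
assert (y == 4.0).all()         # each output sums a 2x2 window of ones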
def _DepthwiseConv2dNumpy(x1, x2, strides, padding, data_format, dilations):
"""Compute depthwise_conv2d using Numpy.
This allows us to test TensorFlow's depthwise_conv2d by comparing it to the
Numpy version.
Unlike `_DepthwiseConv2dNumpyBasic`, this supports more advanced features
like padding.
Args:
x1: The input Numpy array.
x2: The filter Numpy array.
strides: A Python list of 4 elements representing the strides.
padding: The padding. "SAME", "VALID", or a list of explicit paddings.
data_format: "NHWC" or "NCHW".
dilations: A list of 2 elements, representing the dilations.
Returns:
The depthwise conv2d as a Numpy array.
"""
if data_format == "NCHW":
# Transpose arguments from NCHW to NHWC format.
x1 = np.transpose(x1, (0, 2, 3, 1))
strides = [strides[0], strides[2], strides[3], strides[1]]
# `dilations` is a 2-element [height, width] list, so it is the same in
# both formats and needs no reordering.
if dilations:
# Dilate the filter so _DepthwiseConv2dNumpyBasic doesn't have to deal with
# dilations.
fh, fw, c, o = x2.shape
new_fh = (fh - 1) * dilations[0] + 1
new_fw = (fw - 1) * dilations[1] + 1
new_x2 = np.zeros((new_fh, new_fw, c, o))
for i in range(fh):
for j in range(fw):
new_x2[i * dilations[0], j * dilations[1], :, :] = x2[i, j, :, :]
x2 = new_x2
# Pad input so _DepthwiseConv2dNumpyBasic doesn't have to deal with padding.
if padding == "SAME":
def PaddingsForDim(input_dim, filter_dim, stride):
"""Computes paddings for a single dimension."""
if input_dim % stride == 0:
total_padding = max(filter_dim - stride, 0)
else:
total_padding = max(filter_dim - (input_dim % stride), 0)
pad_before = total_padding // 2
pad_after = total_padding - pad_before
return pad_before, pad_after
padding = [(0, 0),
PaddingsForDim(x1.shape[1], x2.shape[0], strides[1]),
PaddingsForDim(x1.shape[2], x2.shape[1], strides[2]),
(0, 0)]
elif padding == "VALID":
padding = [(0, 0)] * 4
x1 = np.pad(x1, padding, "constant")
y = _DepthwiseConv2dNumpyBasic(x1, x2, strides)
if data_format == "NCHW":
# Transpose back to NCHW format.
y = np.transpose(y, (0, 3, 1, 2))
return y
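To see the filter-dilation trick concretely (an illustrative sketch, not commit code): with dilations=[2, 2], a 2x2 filter becomes a 3x3 filter with zeros between the original taps, so the effective extent is (fh - 1) * dilation + 1 in each spatial dimension.

import numpy as np
f = np.array([[1., 2.], [3., 4.]]).reshape(2, 2, 1, 1)
dh = dw = 2
new_f = np.zeros(((2 - 1) * dh + 1, (2 - 1) * dw + 1, 1, 1))
for i in range(2):
  for j in range(2):
    new_f[i * dh, j * dw, :, :] = f[i, j, :, :]
print(new_f[:, :, 0, 0])
# [[1. 0. 2.]
#  [0. 0. 0.]
#  [3. 0. 4.]]

The SAME branch works the same way: with stride 1 it pads a total of filter_dim - 1 (e.g. 2 on each side for a 5x5 filter), which keeps the output spatially equal to the input, so _DepthwiseConv2dNumpyBasic only ever has to handle the VALID case.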
def ConfigsToTest():
"""Iterator for different convolution shapes, strides and paddings.
Yields:
Tuple (input_size, filter_size, out_size, stride, padding), the depthwise
convolution parameters.
Returns:
List of tuples (input_size, filter_size, out_size, stride, padding,
dilations), the depthwise convolution parameters.
"""
input_sizes = [[4, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48], [4, 9, 27, 8],
[4, 31, 31, 7], [4, 35, 35, 2], [4, 147, 147, 2],
[3, 299, 299, 3], [5, 183, 183, 1]]
filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [3, 1, 48, 4], [3, 3, 8, 1],
[3, 3, 7, 1], [5, 5, 2, 1], [3, 3, 2, 8], [2, 2, 3,
8], [5, 5, 1, 2]]
out_sizes = [[4, 5, 5, 96], [4, 8, 8, 84], [4, 17, 17, 192], [4, 9, 27, 8],
[4, 31, 31, 7], [4, 35, 35, 2], [4, 49, 49, 16],
[3, 150, 150, 24], [5, 92, 92, 2]]
strides = [1, 1, 1, 1, 1, 1, 3, 2, 2]
# pylint: disable=invalid-name
VALID = "VALID"
SAME = "SAME"
# pylint: enable=invalid-name
paddings = [SAME, SAME, SAME, SAME, SAME, SAME, VALID, SAME, SAME]
for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
paddings):
yield i, f, o, s, p
def Config(input_size, filter_size, out_size, stride=1, padding="SAME",
dilations=None):
return input_size, filter_size, out_size, stride, padding, dilations
return [
Config([4, 5, 5, 48], [1, 1, 48, 2], [4, 5, 5, 96]),
Config([4, 8, 8, 84], [1, 3, 84, 1], [4, 8, 8, 84]),
Config([4, 17, 17, 48], [3, 1, 48, 4], [4, 17, 17, 192]),
Config([4, 9, 27, 8], [3, 3, 8, 1], [4, 9, 27, 8]),
Config([4, 31, 31, 7], [3, 3, 7, 1], [4, 31, 31, 7]),
Config([4, 35, 35, 2], [5, 5, 2, 1], [4, 35, 35, 2]),
Config([4, 147, 147, 2], [3, 3, 2, 8], [4, 49, 49, 16], 3,
padding="VALID"),
Config([3, 299, 299, 3], [2, 2, 3, 8], [3, 150, 150, 24], 2),
Config([5, 183, 183, 1], [5, 5, 1, 2], [5, 92, 92, 2], 2),
Config([5, 183, 183, 1], [5, 5, 1, 2], [5, 183, 183, 2],
dilations=[2, 2]),
Config([5, 41, 35, 2], [4, 7, 2, 2], [5, 32, 23, 4], padding="VALID",
dilations=[3, 2]),
]
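A quick check of the VALID entries above (my arithmetic, not commit code), using out = (in - effective_filter + stride) // stride with effective_filter = (filter - 1) * dilation + 1:

def ValidOut(in_size, filter_size, stride=1, dilation=1):
  effective = (filter_size - 1) * dilation + 1
  return (in_size - effective + stride) // stride

assert ValidOut(147, 3, stride=3) == 49   # -> out_size [4, 49, 49, 16]
assert ValidOut(41, 4, dilation=3) == 32  # height of [5, 32, 23, 4]
assert ValidOut(35, 7, dilation=2) == 23  # width of [5, 32, 23, 4]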
def CheckGradConfigsToTest():
@@ -67,34 +172,26 @@ def CheckGradConfigsToTest():
compute_gradient_error() is very expensive, so the configs should be
relatively small.
Yields:
Tuple (input_size, filter_size, out_size, stride, padding), the depthwise
convolution parameters.
Returns:
List of tuples (input_size, filter_size, out_size, stride, padding,
dilations), the depthwise convolution parameters.
"""
input_sizes = [[2, 5, 8, 1], [4, 5, 5, 1], [2, 4, 4, 2], [1, 15, 15, 2],
[2, 15, 16, 1]]
filter_sizes = [[4, 4, 1, 2], [2, 2, 1, 2], [3, 1, 2, 2], [1, 3, 2, 1],
[3, 3, 1, 2]]
out_sizes = [[2, 5, 8, 2], [4, 2, 2, 2], [2, 4, 4, 4], [1, 15, 15, 2],
[2, 5, 5, 2]]
strides = [1, 2, 1, 1, 3]
# pylint: disable=invalid-name
VALID = "VALID"
SAME = "SAME"
# pylint: enable=invalid-name
paddings = [SAME, VALID, SAME, SAME, VALID]
for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
paddings):
yield i, f, o, s, p
def Config(input_size, filter_size, out_size, stride=1, padding="SAME",
dilations=None):
return input_size, filter_size, out_size, stride, padding, dilations
return [
Config([2, 5, 8, 1], [4, 4, 1, 2], [2, 5, 8, 2]),
Config([4, 5, 5, 1], [2, 2, 1, 2], [4, 2, 2, 2], 2, padding="VALID"),
Config([2, 4, 4, 2], [3, 1, 2, 2], [2, 4, 4, 4]),
Config([1, 15, 15, 2], [1, 3, 2, 1], [1, 15, 15, 2]),
Config([2, 15, 16, 1], [3, 3, 1, 2], [2, 5, 5, 2], 3, padding="VALID"),
Config([2, 5, 8, 1], [4, 3, 1, 2], [2, 5, 8, 2], dilations=[1, 2]),
]
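For reference, a minimal sketch of the kind of check these configs feed (assuming the public TF1 API; the real test uses internal gradient-checker helpers and per-dtype tolerances):

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

with tf.Session():
  x = tf.constant(np.random.rand(2, 5, 8, 1), dtype=tf.float64)
  f = tf.constant(np.random.rand(4, 4, 1, 2), dtype=tf.float64)
  y = tf.nn.depthwise_conv2d(x, f, strides=[1, 1, 1, 1], padding="SAME")
  # Max difference between numeric and analytic Jacobians.
  err = tf.test.compute_gradient_error(x, [2, 5, 8, 1], y, [2, 5, 8, 2])
  assert err < 1e-4  # lenient bound; float64 error is typically far smaller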
class DepthwiseConv2DTest(test.TestCase):
# This is testing that depthwise_conv2d and depthwise_conv2d_native
# produce the same results. It also tests that NCHW and NHWC
# formats agree, by checking that depthwise_conv2d_native with
# 'NCHW' format (with transposition) matches the 'NHWC' format using
# the higher level interface.
# This tests depthwise_conv2d and depthwise_conv2d_native against a Numpy
# reference implementation.
def _VerifyValues(self,
tensor_in_sizes,
filter_in_sizes,
@@ -103,7 +200,8 @@ class DepthwiseConv2DTest(test.TestCase):
data_type,
use_gpu,
grouped_conv=False,
data_format="NHWC"):
data_format="NHWC",
dilations=None):
"""Verifies the output values of the convolution function.
Args:
@@ -117,6 +215,7 @@ class DepthwiseConv2DTest(test.TestCase):
use_gpu: Whether to use GPU.
grouped_conv: Whether to use cuDNN 7's grouped convolution.
data_format: The data_format of the input. "NHWC" or "NCHW".
dilations: A list of 2 elements, representing the dilations.
"""
input_size = 1
filter_size = 1
@@ -126,7 +225,14 @@ class DepthwiseConv2DTest(test.TestCase):
filter_size *= s
# Initializes the input and filter tensor with numbers incrementing from 1.
x1 = [f * 1.0 / input_size for f in range(1, input_size + 1)]
x1 = np.array(x1).reshape(tensor_in_sizes)
x2 = [f * 1.0 / filter_size for f in range(1, filter_size + 1)]
x2 = np.array(x2).reshape(filter_in_sizes)
# Compute the reference result with the Numpy implementation.
strides = [1, stride, stride, 1]
np_result = _DepthwiseConv2dNumpy(x1, x2, strides, padding, "NHWC",
dilations)
ops.reset_default_graph()
graph = ops.get_default_graph()
with self.session(graph=graph, use_gpu=use_gpu) as sess:
@@ -137,60 +243,62 @@ class DepthwiseConv2DTest(test.TestCase):
}[data_type]
t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=data_type)
t1.set_shape(tensor_in_sizes)
t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=data_type)
native_t1 = t1
strides = [1, stride, stride, 1]
if data_format == "NCHW":
# Transpose from NHWC input to NCHW
# Ex. [4, 5, 5, 48] to [4, 48, 5, 5]
native_t1 = array_ops.transpose(t1, [0, 3, 1, 2])
t1 = array_ops.transpose(t1, [0, 3, 1, 2])
strides = [1, 1, stride, stride]
with sess.graph._kernel_label_map({
"DepthwiseConv2dNative": "cudnn_grouped_convolution"
} if grouped_conv else {}):
conv_native = nn_ops.depthwise_conv2d_native(
native_t1,
t2,
strides=strides,
data_format=data_format,
padding=padding)
# depthwise_conv2d_native does not support dilations except on TPUs.
if dilations is None:
with sess.graph._kernel_label_map({
"DepthwiseConv2dNative": "cudnn_grouped_convolution"
} if grouped_conv else {}):
conv_native = nn_ops.depthwise_conv2d_native(
t1,
t2,
strides=strides,
data_format=data_format,
padding=padding)
if data_format == "NCHW":
# Transpose back from NCHW to NHWC
conv_native = array_ops.transpose(conv_native, [0, 2, 3, 1])
if data_format == "NCHW":
# Transpose back from NCHW to NHWC
conv_native = array_ops.transpose(conv_native, [0, 2, 3, 1])
try:
native_result = self.evaluate(conv_native)
except errors.InvalidArgumentError as e:
# Grouped convolution kernel is only registered for cuDNN 7. Silently
# return when we are running on an earlier version or without GPU.
if e.message.startswith(
"No OpKernel was registered to support Op 'DepthwiseConv2dNative'"):
tf_logging.warn("Skipping grouped convolution test")
return
raise e
try:
# The Numpy array from calling depthwise_conv2d_native
native_result = self.evaluate(conv_native)
except errors.InvalidArgumentError as e:
# Grouped convolution kernel is only registered for cuDNN 7. Silently
# return when we are running on an earlier version or without GPU.
if e.message.startswith(
"No OpKernel was registered to support Op "
"'DepthwiseConv2dNative'"):
tf_logging.warn("Skipping grouped convolution test")
return
raise e
conv_interface = nn_impl.depthwise_conv2d(
t1, t2, strides=[1, stride, stride, 1], padding=padding)
t1, t2, strides=strides, padding=padding,
data_format=data_format, dilations=dilations)
if data_format == "NCHW":
# Transpose back from NCHW to NHWC
conv_interface = array_ops.transpose(conv_interface, [0, 2, 3, 1])
# The Numpy array from calling depthwise_conv2d
interface_result = self.evaluate(conv_interface)
tf_logging.info(
"data_type: %r, use_gpu: %r, grouped_conv: %r, max diff = %f",
data_type, use_gpu, grouped_conv,
np.amax(np.absolute(native_result - interface_result)))
self.assertArrayNear(
np.ravel(native_result), np.ravel(interface_result), tolerance)
self.assertShapeEqual(native_result, conv_native)
self.assertShapeEqual(native_result, conv_interface)
if dilations is None:
self.assertAllClose(native_result, np_result, atol=tolerance, rtol=0.)
self.assertAllClose(interface_result, np_result, atol=tolerance, rtol=0.)
@test_util.run_v1_only("b/120545219")
@test_util.run_cuda_only
def testDepthwiseConv2DCudnn(self):
for index, (input_size, filter_size, _, stride,
padding) in enumerate(ConfigsToTest()):
padding, dilations) in enumerate(ConfigsToTest()):
# The CuDNN depthwise conv is turned on only when input/output is NCHW and
# float16(half). See cudnn release note 7.6.3.
tf_logging.info(
@@ -204,12 +312,13 @@ class DepthwiseConv2DTest(test.TestCase):
padding,
data_type,
use_gpu=True,
data_format="NCHW")
data_format="NCHW",
dilations=dilations)
@test_util.run_v1_only("b/120545219")
def testDepthwiseConv2D(self):
for index, (input_size, filter_size, _, stride,
padding) in enumerate(ConfigsToTest()):
padding, dilations) in enumerate(ConfigsToTest()):
tf_logging.info(
"Testing DepthwiseConv2D, %dth config: %r * %r, stride: %d, padding: "
"%s", index, input_size, filter_size, stride, padding)
@@ -219,7 +328,8 @@ class DepthwiseConv2DTest(test.TestCase):
for data_type in ([dtypes.float32] + optional_float64):
tf_logging.info("Testing without grouped_conv")
self._VerifyValues(
input_size, filter_size, stride, padding, data_type, use_gpu=True)
input_size, filter_size, stride, padding, data_type, use_gpu=True,
dilations=dilations)
tf_logging.info("Testing with grouped_conv")
self._VerifyValues(
input_size,
@@ -228,7 +338,8 @@ class DepthwiseConv2DTest(test.TestCase):
padding,
data_type,
use_gpu=True,
grouped_conv=True)
grouped_conv=True,
dilations=dilations)
@test_util.run_v1_only("b/120545219")
def testDepthwiseConv2DWithUnknownShape(self):
@@ -250,7 +361,7 @@ class DepthwiseConv2DTest(test.TestCase):
return
for index, (input_size, filter_size, _, stride,
padding) in enumerate(ConfigsToTest()):
padding, dilations) in enumerate(ConfigsToTest()):
tf_logging.info(
"Testing DepthwiseConv2DFormat, %dth config: %r * %r, stride: %d, "
"padding: %s", index, input_size, filter_size, stride, padding)
@@ -265,7 +376,8 @@ class DepthwiseConv2DTest(test.TestCase):
padding,
data_type,
use_gpu=True,
data_format="NCHW")
data_format="NCHW",
dilations=dilations)
# This is testing against hand calculated results.
@@ -385,7 +497,8 @@ class DepthwiseConv2DTest(test.TestCase):
test_input,
use_gpu,
grouped_conv=False,
data_format="NHWC"):
data_format="NHWC",
dilations=None):
input_size = 1
for x in input_shape:
input_size *= x
@@ -393,7 +506,9 @@ class DepthwiseConv2DTest(test.TestCase):
for x in filter_shape:
filter_size *= x
input_data = [x * 1.0 / input_size for x in range(0, input_size)]
input_np = np.array(input_data).reshape(input_shape)
filter_data = [x * 1.0 / filter_size for x in range(0, filter_size)]
filter_np = np.array(filter_data).reshape(filter_shape)
ops.reset_default_graph()
graph = ops.get_default_graph()
with self.session(graph=graph, use_gpu=use_gpu) as sess:
@@ -404,9 +519,9 @@ class DepthwiseConv2DTest(test.TestCase):
}[data_type]
input_tensor = constant_op.constant(
input_data, shape=input_shape, dtype=data_type, name="input")
input_np, shape=input_shape, dtype=data_type, name="input")
filter_tensor = constant_op.constant(
filter_data, shape=filter_shape, dtype=data_type, name="filter")
filter_np, shape=filter_shape, dtype=data_type, name="filter")
native_input = input_tensor
strides = [1, stride, stride, 1]
@@ -427,12 +542,13 @@ class DepthwiseConv2DTest(test.TestCase):
"DepthwiseConv2dNativeBackpropInput": "cudnn_grouped_convolution",
"DepthwiseConv2dNativeBackpropFilter": "cudnn_grouped_convolution",
} if grouped_conv else {}):
depthwise_conv2d = nn_ops.depthwise_conv2d_native(
depthwise_conv2d = nn_impl.depthwise_conv2d(
native_input,
filter_tensor,
strides,
padding,
data_format=data_format,
dilations=dilations,
name="depthwise_conv2d")
self.assertEqual(output_shape, depthwise_conv2d.get_shape())
@@ -462,7 +578,7 @@ class DepthwiseConv2DTest(test.TestCase):
@test_util.run_cuda_only
def testDepthwiseConv2DInputGradCudnn(self):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(CheckGradConfigsToTest()):
padding, dilations) in enumerate(CheckGradConfigsToTest()):
# The CuDNN depthwise conv (input gradient) is turned on only when
# stride = 1, input/output is NCHW and float16(half). See cudnn release
# note 7.6.3.
@@ -482,12 +598,13 @@ class DepthwiseConv2DTest(test.TestCase):
data_type,
test_input=True,
use_gpu=True,
data_format="NCHW")
data_format="NCHW",
dilations=dilations)
@test_util.run_v1_only("b/120545219")
def testDepthwiseConv2DInputGrad(self):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(CheckGradConfigsToTest()):
padding, dilations) in enumerate(CheckGradConfigsToTest()):
tf_logging.info(
"Testing DepthwiseConv2DInputGrad, %dth config: %r * %r, stride: %d, "
"padding: %s", index, input_size, filter_size, stride, padding)
@@ -503,7 +620,8 @@ class DepthwiseConv2DTest(test.TestCase):
padding,
data_type,
test_input=True,
use_gpu=True)
use_gpu=True,
dilations=dilations)
self._ConstructAndTestGradient(
input_size,
filter_size,
@@ -513,7 +631,8 @@ class DepthwiseConv2DTest(test.TestCase):
data_type,
test_input=True,
use_gpu=True,
grouped_conv=True)
grouped_conv=True,
dilations=dilations)
@test_util.run_v1_only("b/120545219")
def testDepthwiseConv2DInputGradFormat(self):
@@ -521,7 +640,7 @@ class DepthwiseConv2DTest(test.TestCase):
return
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(CheckGradConfigsToTest()):
padding, dilations) in enumerate(CheckGradConfigsToTest()):
tf_logging.info(
"Testing DepthwiseConv2DInputGradFormat, %dth config: %r * %r, "
"stride: %d, padding: %s", index, input_size, filter_size, stride,
@@ -539,13 +658,14 @@ class DepthwiseConv2DTest(test.TestCase):
data_type,
test_input=True,
use_gpu=True,
data_format="NCHW")
data_format="NCHW",
dilations=dilations)
@test_util.run_v1_only("b/120545219")
@test_util.run_cuda_only
def testDepthwiseConv2DFilterGradCudnn(self):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(CheckGradConfigsToTest()):
padding, dilations) in enumerate(CheckGradConfigsToTest()):
# The CuDNN depthwise conv (filter gradient) is turned on only when
# input/output is float16(half). See cudnn release note 7.6.3.
tf_logging.info(
@@ -562,7 +682,8 @@ class DepthwiseConv2DTest(test.TestCase):
data_type,
test_input=False,
use_gpu=True,
data_format="NCHW")
data_format="NCHW",
dilations=dilations)
self._ConstructAndTestGradient(
input_size,
filter_size,
@@ -572,12 +693,13 @@ class DepthwiseConv2DTest(test.TestCase):
data_type,
test_input=False,
use_gpu=True,
data_format="NHWC")
data_format="NHWC",
dilations=dilations)
@test_util.run_v1_only("b/120545219")
def testDepthwiseConv2DFilterGrad(self):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(CheckGradConfigsToTest()):
padding, dilations) in enumerate(CheckGradConfigsToTest()):
tf_logging.info(
"Testing DepthwiseConv2DFilterGrad, %dth config: %r * %r, stride: "
"%d, padding: %s", index, input_size, filter_size, stride, padding)
@@ -593,7 +715,8 @@ class DepthwiseConv2DTest(test.TestCase):
padding,
data_type,
test_input=False,
use_gpu=True)
use_gpu=True,
dilations=dilations)
@test_util.run_v1_only("b/120545219")
def testDepthwiseConv2DFilterGradFormat(self):
@@ -601,7 +724,7 @@ class DepthwiseConv2DTest(test.TestCase):
return
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(CheckGradConfigsToTest()):
padding, dilations) in enumerate(CheckGradConfigsToTest()):
tf_logging.info(
"Testing DepthwiseConv2DFilterGradFormat, %dth config: %r * %r, "
"stride: %d, padding: %s", index, input_size, filter_size, stride,
@@ -619,32 +742,13 @@ class DepthwiseConv2DTest(test.TestCase):
data_type,
test_input=False,
use_gpu=True,
data_format="NCHW")
data_format="NCHW",
dilations=dilations)
def _CompareBackpropInputFloat(self, input_sizes, filter_sizes, output_sizes,
stride, padding):
x1 = np.random.rand(*filter_sizes).astype(np.float32)
x2 = np.random.rand(*output_sizes).astype(np.float32)
def _GetVal(use_gpu):
with self.cached_session(use_gpu=use_gpu):
t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)])
t1 = constant_op.constant(x1, shape=filter_sizes)
t2 = constant_op.constant(x2, shape=output_sizes)
backprop = nn_ops.depthwise_conv2d_native_backprop_input(
t0, t1, t2, strides=[1, stride, stride, 1], padding=padding)
ret = self.evaluate(backprop)
self.assertShapeEqual(ret, backprop)
return ret
gpu_value = _GetVal(use_gpu=True)
cpu_value = _GetVal(use_gpu=False)
self.assertAllClose(cpu_value, gpu_value, rtol=1e-4, atol=1e-4)
def _CompareBackpropInputDouble(self, input_sizes, filter_sizes, output_sizes,
stride, padding):
x1 = np.random.rand(*filter_sizes).astype(np.float64)
x2 = np.random.rand(*output_sizes).astype(np.float64)
def _CompareBackpropInput(self, input_sizes, filter_sizes, output_sizes,
stride, padding, dtype):
x1 = np.random.rand(*filter_sizes).astype(dtype)
x2 = np.random.rand(*output_sizes).astype(dtype)
def _GetVal(use_gpu):
with self.cached_session(use_gpu=use_gpu):
@@ -663,44 +767,26 @@ class DepthwiseConv2DTest(test.TestCase):
def testDepthwiseConv2DInputGradCompare(self):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(ConfigsToTest()):
padding, dilations) in enumerate(ConfigsToTest()):
if dilations:
continue
tf_logging.info(
"Testing DepthwiseConv2DInputGradCompare, %dth config: %r * %r, "
"stride: %d, padding: %s", index, input_size, filter_size, stride,
padding)
self._CompareBackpropInputFloat(input_size, filter_size, output_size,
stride, padding)
self._CompareBackpropInput(input_size, filter_size, output_size, stride,
padding, "float32")
# double datatype is currently not supported for convolution ops
# on the ROCm platform
if test.is_built_with_rocm():
continue
self._CompareBackpropInputDouble(input_size, filter_size, output_size,
stride, padding)
self._CompareBackpropInput(input_size, filter_size, output_size, stride,
padding, "float64")
def _CompareBackpropFilterFloat(self, input_sizes, filter_sizes, output_sizes,
stride, padding):
x0 = np.random.rand(*input_sizes).astype(np.float32)
x2 = np.random.rand(*output_sizes).astype(np.float32)
def _GetVal(use_gpu):
with self.cached_session(use_gpu=use_gpu):
t0 = constant_op.constant(x0, shape=input_sizes)
t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
t2 = constant_op.constant(x2, shape=output_sizes)
backprop = nn_ops.depthwise_conv2d_native_backprop_filter(
t0, t1, t2, strides=[1, stride, stride, 1], padding=padding)
ret = self.evaluate(backprop)
self.assertShapeEqual(ret, backprop)
return ret
gpu_value = _GetVal(use_gpu=True)
cpu_value = _GetVal(use_gpu=False)
self.assertAllClose(cpu_value, gpu_value, rtol=1e-4, atol=1e-4)
def _CompareBackpropFilterDouble(self, input_sizes, filter_sizes,
output_sizes, stride, padding):
x0 = np.random.rand(*input_sizes).astype(np.float64)
x2 = np.random.rand(*output_sizes).astype(np.float64)
def _CompareBackpropFilter(self, input_sizes, filter_sizes, output_sizes,
stride, padding, dtype):
x0 = np.random.rand(*input_sizes).astype(dtype)
x2 = np.random.rand(*output_sizes).astype(dtype)
def _GetVal(use_gpu):
with self.cached_session(use_gpu=use_gpu):
@@ -719,19 +805,21 @@ class DepthwiseConv2DTest(test.TestCase):
def testDepthwiseConv2DFilterGradCompare(self):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(ConfigsToTest()):
padding, dilations) in enumerate(ConfigsToTest()):
if dilations:
continue
tf_logging.info(
"Testing DepthwiseConv2DFilterGradCompare, %dth config: %r * %r, "
"stride: %d, padding: %s", index, input_size, filter_size, stride,
padding)
self._CompareBackpropFilterFloat(input_size, filter_size, output_size,
stride, padding)
self._CompareBackpropFilter(input_size, filter_size, output_size, stride,
padding, "float32")
# double datatype is currently not supported for convolution ops
# on the ROCm platform
if test.is_built_with_rocm():
continue
self._CompareBackpropFilterDouble(input_size, filter_size, output_size,
stride, padding)
self._CompareBackpropFilter(input_size, filter_size, output_size, stride,
padding, "float64")
if __name__ == "__main__":