[TF:NN:CONVOLUTION] Don't call space to batch in depthwise convolution on TPU.

PiperOrigin-RevId: 261242540
A. Unique TensorFlower 2019-08-01 18:46:07 -07:00 committed by TensorFlower Gardener
parent 5ee9c10e83
commit 20506ddda8
5 changed files with 369 additions and 14 deletions
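Note: the net effect is that tf.nn.depthwise_conv2d built with a dilation rate inside a TPU (XLA) context now calls depthwise_conv2d_native with dilations directly, instead of wrapping a non-dilated op in space_to_batch/batch_to_space. A minimal sketch of the two lowerings, assuming TF 1.x graph-mode APIs; the shapes below are hypothetical, and the zero paddings/crops are only valid because the spatial dims are multiples of the dilation rate:

import tensorflow as tf  # assumes TF 1.x-style APIs, matching this commit

x = tf.placeholder(tf.float32, [4, 36, 36, 2])   # hypothetical NHWC input
w = tf.placeholder(tf.float32, [5, 5, 2, 1])     # hypothetical HWIO depthwise filter

# CPU/GPU path (and the old behavior everywhere): emulate dilation by
# space_to_batch -> non-dilated depthwise conv -> batch_to_space.
x_s2b = tf.space_to_batch(x, paddings=[[0, 0], [0, 0]], block_size=2)
y_s2b = tf.nn.depthwise_conv2d_native(
    x_s2b, w, strides=[1, 1, 1, 1], padding="VALID")
y_emulated = tf.batch_to_space(y_s2b, crops=[[0, 0], [0, 0]], block_size=2)

# New TPU path: a single native op carrying the dilation; for these shapes
# both graphs produce the same [4, 28, 28, 2] result.
y_native = tf.nn.depthwise_conv2d_native(
    x, w, strides=[1, 1, 1, 1], padding="VALID", dilations=[1, 2, 2, 1])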


@@ -25,6 +25,7 @@ from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
@@ -87,6 +88,32 @@ def ConfigsToTest():
yield i, f, o, s, p
def ConfigsWithDilationsToTest():
"""Iterator for different convolution shapes, strides and paddings.
Yields:
Tuple (input_size, filter_size, out_size, stride, dilation, padding), the
depthwise
convolution parameters.
"""
input_sizes = [[4, 6, 6, 48], [4, 8, 8, 84], [4, 36, 36, 2], [4, 148, 148, 2],
[3, 300, 300, 3]]
filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [5, 5, 2, 1], [4, 4, 2, 8],
[2, 2, 3, 8]]
out_sizes = [[4, 6, 6, 96], [4, 8, 8, 84], [4, 36, 36, 2], [4, 74, 74, 16],
[3, 296, 296, 24]]
strides = [1, 1, 2, 2, 1]
dilations = [2, 2, 4, 2, 4]
# pylint: disable=invalid-name
VALID = "VALID"
SAME = "SAME"
# pylint: enable=invalid-name
paddings = [SAME, SAME, SAME, SAME, VALID]
for i, f, o, s, d, p in zip(input_sizes, filter_sizes, out_sizes, strides,
dilations, paddings):
yield i, f, o, s, d, p
def CheckGradConfigsToTest():
"""Iterator for different convolution shapes, strides and paddings.
@@ -315,6 +342,118 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
padding="VALID",
expected=expected_output)
# This is testing that depthwise_conv2d with dilation produces
# the same results between CPU and TPU. It also tests that NCHW
# and NHWC formats agree.
def _VerifyValuesWithDilation(self,
tensor_in_sizes,
filter_in_sizes,
stride,
dilation,
padding,
data_type,
data_format="NHWC"):
"""Verifies the output values of the convolution function.
Args:
tensor_in_sizes: Input tensor dimensions in [batch, input_rows,
input_cols, input_depth].
filter_in_sizes: Filter tensor dimensions in [filter_rows, filter_cols,
input_depth, depth_multiplier].
stride: Stride.
dilation: Dilation.
padding: Padding type.
data_type: The data type to use.
data_format: The data_format of the input. "NHWC" or "NCHW".
"""
total_size_1 = 1
total_size_2 = 1
for s in tensor_in_sizes:
total_size_1 *= s
for s in filter_in_sizes:
total_size_2 *= s
# Initializes the input and filter tensor with numbers incrementing from 1.
x1 = np.array([f * 1.0 for f in range(1, total_size_1 + 1)],
dtype=data_type).reshape(tensor_in_sizes)
x2 = np.array([f * 1.0 for f in range(1, total_size_2 + 1)],
dtype=data_type).reshape(filter_in_sizes)
with self.session() as sess:
if data_type == np.float32:
# TODO(b/64210055): Tolerance for TPU is high.
tolerance = 1e-2
else:
self.assertEqual(data_type, np.float64)
tolerance = 1e-8
t1 = array_ops.placeholder(shape=tensor_in_sizes, dtype=data_type)
t2 = array_ops.placeholder(shape=filter_in_sizes, dtype=data_type)
native_t1 = t1
strides = [1, stride, stride, 1]
dilations = [dilation, dilation]
if data_format == "NCHW":
# Transpose from NHWC input to NCHW
# Ex. [4, 5, 5, 48] to [4, 48, 5, 5]
native_t1 = array_ops.transpose(t1, [0, 3, 1, 2])
strides = [1, 1, stride, stride]
with self.test_scope():
conv_native = nn_impl.depthwise_conv2d(
native_t1,
t2,
strides=strides,
rate=dilations,
data_format=data_format,
padding=padding)
if data_format == "NCHW":
# Transpose back from NCHW to NHWC
conv_native = array_ops.transpose(conv_native, [0, 2, 3, 1])
with ops.device("CPU"):
# CPU only supports the NHWC format
strides = [1, stride, stride, 1]
conv_interface = nn_impl.depthwise_conv2d(
t1, t2, strides=strides, rate=dilations, padding=padding)
native_result = sess.run(conv_native, {t1: x1, t2: x2})
interface_result = sess.run(conv_interface, {t1: x1, t2: x2})
print("data_type:", data_type, "max diff = ",
np.amax(np.absolute(native_result - interface_result)))
self.assertAllClose(
np.ravel(native_result), np.ravel(interface_result), rtol=tolerance)
def testDilationDepthwiseConv2DWith(self):
for index, (input_size, filter_size, _, stride, dilation,
padding) in enumerate(ConfigsWithDilationsToTest()):
print("Testing DilationDepthwiseConv2D,", index, "th config:", input_size,
"*", filter_size, "stride:", stride, "dilation: ", dilation,
"padding:", padding)
for data_type in self.float_types:
# TODO(phawkins): the reference implementation only supports float32.
if data_type == np.float32:
self._VerifyValuesWithDilation(input_size, filter_size, stride,
dilation, padding, data_type)
def testDilationDepthwiseConv2DWithFormat(self):
for index, (input_size, filter_size, _, stride, dilation,
padding) in enumerate(ConfigsWithDilationsToTest()):
print("Testing DilationDepthwiseConv2DFormat,", index, "th config:",
input_size, "*", filter_size, "stride:", stride, "dilation:",
dilation, "padding:", padding)
for data_type in self.float_types:
# TODO(phawkins): the reference implementation only supports float32.
if data_type == np.float32:
self._VerifyValuesWithDilation(
input_size,
filter_size,
stride,
dilation,
padding,
data_type,
data_format="NCHW")
def _CompareBackpropInput(self, input_sizes, filter_sizes, output_sizes,
stride, padding):
x1 = np.random.rand(*filter_sizes).astype(np.float32)
@@ -420,5 +559,139 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
padding,
data_format="NCHW")
def _CompareBackpropInputWithDilation(self, input_sizes, filter_sizes,
output_sizes, stride, dilation,
padding):
x1 = np.random.rand(*filter_sizes).astype(np.float32)
x2 = np.random.rand(*output_sizes).astype(np.float32)
def _GetVal(use_xla):
with self.session():
t1 = array_ops.placeholder(np.float32, shape=filter_sizes)
t2 = array_ops.placeholder(np.float32, shape=output_sizes)
if use_xla:
with self.test_scope():
t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)])
backprop = nn_ops.depthwise_conv2d_native_backprop_input(
t0,
t1,
t2,
strides=[1, stride, stride, 1],
dilations=[1, dilation, dilation, 1],
padding=padding)
else:
# TODO(wangtao): figure out gradient with stride > 1.
# depthwise_conv2d_native_backprop_input on CPU doesn't support
# dilation.
t3 = array_ops.space_to_batch(
t2, block_size=dilation, paddings=[[0, 0], [0, 0]])
input_sizes_transform = [
input_sizes[0] * dilation * dilation, input_sizes[1] // dilation,
input_sizes[2] // dilation, input_sizes[3]
]
t0 = constant_op.constant(
input_sizes_transform, shape=[len(input_sizes)])
backprop_naive = nn_ops.depthwise_conv2d_native_backprop_input(
t0, t1, t3, strides=[1, stride, stride, 1], padding=padding)
backprop = array_ops.batch_to_space(
backprop_naive, [[0, 0], [0, 0]], block_size=dilation)
ret = backprop.eval({t1: x1, t2: x2})
self.assertShapeEqual(ret, backprop)
return ret
gpu_value = _GetVal(use_xla=True)
cpu_value = _GetVal(use_xla=False)
# TODO (b/64210055): Tolerance for TPU is high.
self.assertAllClose(cpu_value, gpu_value, rtol=1e-2, atol=1e-3)
def testDilationDepthwiseConv2DInputGradWithCompare(self):
for index, (input_size, filter_size, output_size, stride, dilation,
padding) in enumerate(ConfigsWithDilationsToTest()):
print("Testing DilationDepthwiseConv2DInputGradWithDilationCompare,",
index, "th config:", input_size, "*", filter_size, "stride:",
stride, "dilation:", dilation, "padding:", padding)
# TODO(wangtao): implement CPU grad computation with stride > 1.
if stride == 1:
self._CompareBackpropInputWithDilation(input_size, filter_size,
output_size, stride, dilation,
padding)
def _CompareBackpropFilterWithDilation(self,
input_sizes,
filter_sizes,
output_sizes,
stride,
dilation,
padding,
data_format="NHWC"):
x0 = np.random.rand(*input_sizes).astype(np.float32)
x2 = np.random.rand(*output_sizes).astype(np.float32)
def _GetVal(use_xla):
with self.session():
t0 = array_ops.placeholder(np.float32, shape=input_sizes)
t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
t2 = array_ops.placeholder(np.float32, shape=output_sizes)
native_t0 = t0
native_t2 = t2
strides = [1, stride, stride, 1]
dilations = [1, dilation, dilation, 1]
if use_xla:
if data_format == "NCHW":
# Transpose from NHWC input to NCHW
# Ex. [4, 5, 5, 48] to [4, 48, 5, 5]
native_t0 = array_ops.transpose(t0, [0, 3, 1, 2])
native_t2 = array_ops.transpose(t2, [0, 3, 1, 2])
strides = [1, 1, stride, stride]
dilations = [1, 1, dilation, dilation]
with self.test_scope():
backprop = nn_ops.depthwise_conv2d_native_backprop_filter(
native_t0,
t1,
native_t2,
strides=strides,
padding=padding,
dilations=dilations,
data_format=data_format)
else:
# For CPU, the format NCHW is not supported. Therefore we always use
# NHWC here.
# depthwise_conv2d_native_backprop_filter on CPU doesn't support
# dilation.
native_t3 = array_ops.space_to_batch(
native_t2, block_size=dilation, paddings=[[0, 0], [0, 0]])
native_t0_transform = array_ops.space_to_batch(
native_t0, block_size=dilation, paddings=[[0, 0], [0, 0]])
backprop = nn_ops.depthwise_conv2d_native_backprop_filter(
native_t0_transform,
t1,
native_t3,
strides=strides,
padding=padding)
ret = backprop.eval({t0: x0, t2: x2})
self.assertShapeEqual(ret, backprop)
return ret
gpu_value = _GetVal(use_xla=True)
cpu_value = _GetVal(use_xla=False)
# TODO(b/64210055): Tolerance for TPU is high.
self.assertAllClose(cpu_value, gpu_value, rtol=1e-3, atol=1e-4)
def testDilationDepthwiseConv2DFilterGradCompare(self):
for index, (input_size, filter_size, output_size, stride, dilation,
padding) in enumerate(ConfigsWithDilationsToTest()):
print("Testing DilationDepthwiseConv2DFilterGradCompare,", index,
"th config:", input_size, "*", filter_size, "producing output",
output_size, "stride:", stride, "dilation:", dilation, "padding:",
padding)
if stride == 1:
# TODO(wangtao): implement CPU grad computation with stride > 1.
self._CompareBackpropFilterWithDilation(input_size, filter_size,
output_size, stride, dilation,
padding)
if __name__ == "__main__":
test.main()


@@ -408,11 +408,15 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
errors::InvalidArgument(
"Current implementation does not yet support "
"dilations in the batch and depth dimensions."));
-// TODO(yangzihao): Add a CPU implementation for dilated convolution.
-OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
-errors::InvalidArgument(
-"Current libxsmm and customized CPU implementations do "
-"not yet support dilation rates larger than 1."));
+if (std::is_same<Device, CPUDevice>::value ||
+std::is_same<Device, GPUDevice>::value) {
+// TODO(yangzihao): Add a CPU implementation for dilated convolution.
+OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
+errors::InvalidArgument(
+"Current libxsmm and customized CPU implementations do "
+"not yet support dilation rates larger than 1."));
+dilations_ = {1, 1, 1, 1};
+}
}
void Compute(OpKernelContext* context) override {
@@ -434,8 +438,8 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
context,
ConvBackpropComputeDimensionsV2(
"Conv2DCustomBackpropFilter", /*num_spatial_dims=*/2, input.shape(),
-filter_shape, out_backprop.shape(), /*dilations=*/{1, 1, 1, 1},
-strides_, padding_, explicit_paddings_, data_format_, &dims));
+filter_shape, out_backprop.shape(), dilations_, strides_, padding_,
+explicit_paddings_, data_format_, &dims));
Tensor* filter_backprop;
OP_REQUIRES_OK(context,


@@ -628,15 +628,17 @@ def _DepthwiseConv2dNativeGrad(op, grad):
array_ops.shape(op.inputs[0]),
op.inputs[1],
grad,
-op.get_attr("strides"),
-op.get_attr("padding"),
+dilations=op.get_attr("dilations"),
+strides=op.get_attr("strides"),
+padding=op.get_attr("padding"),
data_format=op.get_attr("data_format")),
nn_ops.depthwise_conv2d_native_backprop_filter(
op.inputs[0],
array_ops.shape(op.inputs[1]),
grad,
-op.get_attr("strides"),
-op.get_attr("padding"),
+dilations=op.get_attr("dilations"),
+strides=op.get_attr("strides"),
+padding=op.get_attr("padding"),
data_format=op.get_attr("data_format"))
]
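With this gradient change, differentiating a dilated depthwise_conv2d_native forwards the op's dilations attribute to both backprop ops instead of dropping it. A minimal sketch, assuming TF 1.x graph-mode APIs and hypothetical shapes:

import tensorflow as tf  # assumes TF 1.x-style APIs, matching this commit

x = tf.placeholder(tf.float32, [4, 36, 36, 2])   # hypothetical NHWC input
w = tf.placeholder(tf.float32, [5, 5, 2, 1])     # hypothetical depthwise filter
y = tf.nn.depthwise_conv2d_native(
    x, w, strides=[1, 1, 1, 1], padding="SAME", dilations=[1, 2, 2, 1])

# The registered gradient now builds backprop_input/backprop_filter with
# dilations=[1, 2, 2, 1], so dx and dw are gradients of the dilated op.
dx, dw = tf.gradients(tf.reduce_sum(y), [x, w])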


@@ -715,6 +715,22 @@ def zero_fraction(value, name=None):
return array_ops.identity(zero_fraction_float32, "fraction")
# copybara:strip_begin
# TODO(b/138808492): Remove code inside copybara
# to make TPU code and CPU code consistent.
def _enclosing_tpu_context():
# pylint: disable=protected-access
context = ops.get_default_graph()._get_control_flow_context()
# pylint: enable=protected-access
while context is not None and not isinstance(
context, control_flow_ops.XLAControlFlowContext):
context = context.outer_context
return context
# copybara:strip_end
# pylint: disable=redefined-builtin
@tf_export(v1=["nn.depthwise_conv2d"])
def depthwise_conv2d(input,
@@ -774,6 +790,25 @@ def depthwise_conv2d(input,
if rate is None:
rate = [1, 1]
# copybara:strip_begin
# TODO(b/138808492): Remove code inside copybara
# to make TPU code and CPU code consistent.
# Use depthwise_conv2d_native if executing on TPU.
if _enclosing_tpu_context() is not None:
if data_format == "NCHW":
dilations = [1, 1, rate[0], rate[1]]
else:
dilations = [1, rate[0], rate[1], 1]
return nn_ops.depthwise_conv2d_native(
input=input,
filter=filter,
strides=strides,
padding=padding,
data_format=data_format,
dilations=dilations,
name=name)
# copybara:strip_end
def op(input_converted, _, padding):
return nn_ops.depthwise_conv2d_native(
input=input_converted,


@@ -33,6 +33,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
@@ -919,6 +920,22 @@ convolution_v2.__doc__ = deprecation.rewrite_argument_docstring(
"filter", "filters")
# copybara:strip_begin
# TODO(b/138808492): Remove code inside copybara
# to make TPU code and CPU code consistent.
def _enclosing_tpu_context():
# pylint: disable=protected-access
run_context = ops.get_default_graph()._get_control_flow_context()
# pylint: enable=protected-access
while run_context is not None and not isinstance(
run_context, control_flow_ops.XLAControlFlowContext):
run_context = run_context.outer_context
return run_context
# copybara:strip_end
def convolution_internal(
input, # pylint: disable=redefined-builtin
filters,
@@ -958,8 +975,14 @@ def convolution_internal(
conv_ops = {1: conv1d, 2: gen_nn_ops.conv2d, 3: gen_nn_ops.conv3d}
-if all(i == 1 for i in dilations):
-# fast path if no dilation as gradient only supported on GPU for dilations
+# copybara:strip_begin
+# TODO(b/138808492): Remove code inside copybara
+# to make TPU code and CPU code consistent.
+if _enclosing_tpu_context() is not None or all(i == 1 for i in dilations):
+# fast path for TPU or if no dilation as gradient only supported on GPU
+# for dilations
+# copybara:strip_end
+# copybara:insert if all(i == 1 for i in dilations):
op = conv_ops[n]
return op(
input,
@@ -1056,7 +1079,9 @@ class Convolution(object):
self.filter_shape = filter_shape
self.data_format = data_format
self.strides = strides
self.padding = padding
self.name = name
self.dilation_rate = dilation_rate
self.conv_op = _WithSpaceToBatch(
input_shape,
dilation_rate=dilation_rate,
@@ -1076,7 +1101,23 @@
name=self.name)
def __call__(self, inp, filter): # pylint: disable=redefined-builtin
-return self.conv_op(inp, filter)
+# copybara:strip_begin
+# TODO(b/138808492): Remove code inside copybara
+# to make TPU code and CPU code consistent.
+# TPU convolution supports dilations greater than 1.
+if _enclosing_tpu_context() is not None:
+return convolution_internal(
+inp,
+filter,
+strides=self.strides,
+padding=self.padding,
+data_format=self.data_format,
+dilations=self.dilation_rate,
+name=self.name)
+else:
+return self.conv_op(inp, filter)
+# copybara:strip_end
+# copybara:insert return self.conv_op(inp, filter)
@tf_export(v1=["nn.pool"])
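Net effect of the nn_ops.py changes: inside a TPU (XLA) context, both convolution_internal and a Convolution object built with dilation_rate > 1 now emit a single native convolution carrying dilations, while CPU/GPU keep the space_to_batch-based lowering. A minimal usage sketch, assuming TF 1.x graph-mode APIs and hypothetical shapes:

import tensorflow as tf  # assumes TF 1.x-style APIs, matching this commit

x = tf.placeholder(tf.float32, [1, 28, 28, 8])    # hypothetical NHWC input
w = tf.placeholder(tf.float32, [3, 3, 8, 16])     # hypothetical HWIO filter

# Built inside a TPU/XLA context this call now lowers to one conv2d with
# dilations=[1, 2, 2, 1]; elsewhere it still uses the space_to_batch path.
y = tf.nn.convolution(x, w, padding="SAME", dilation_rate=[2, 2])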