From 20506ddda860b79ff4a5e00fdcb0242f8498f60c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 1 Aug 2019 18:46:07 -0700
Subject: [PATCH] [TF:NN:CONVOLUTION] Don't call space to batch in depthwise
 convolution on TPU.

PiperOrigin-RevId: 261242540
---
 .../compiler/tests/depthwise_conv_op_test.py | 273 ++++++++++++++++++
 .../core/kernels/conv_grad_filter_ops.cc     |  18 +-
 tensorflow/python/ops/nn_grad.py             |  10 +-
 tensorflow/python/ops/nn_impl.py             |  35 +++
 tensorflow/python/ops/nn_ops.py              |  47 ++-
 5 files changed, 369 insertions(+), 14 deletions(-)

diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py
index c55bc23cf47..a49985f0446 100644
--- a/tensorflow/compiler/tests/depthwise_conv_op_test.py
+++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py
@@ -25,6 +25,7 @@ from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import nn_impl
 from tensorflow.python.ops import nn_ops
 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import test
@@ -87,6 +88,32 @@ def ConfigsToTest():
     yield i, f, o, s, p
 
 
+def ConfigsWithDilationsToTest():
+  """Iterator for different convolution shapes, strides, dilations and
+  paddings.
+
+  Yields:
+    Tuple (input_size, filter_size, out_size, stride, dilation, padding),
+    the depthwise convolution parameters.
+  """
+  input_sizes = [[4, 6, 6, 48], [4, 8, 8, 84], [4, 36, 36, 2], [4, 148, 148, 2],
+                 [3, 300, 300, 3]]
+  filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [5, 5, 2, 1], [4, 4, 2, 8],
+                  [2, 2, 3, 8]]
+  out_sizes = [[4, 6, 6, 96], [4, 8, 8, 84], [4, 36, 36, 2], [4, 74, 74, 16],
+               [3, 296, 296, 24]]
+  strides = [1, 1, 2, 2, 1]
+  dilations = [2, 2, 4, 2, 4]
+  # pylint: disable=invalid-name
+  VALID = "VALID"
+  SAME = "SAME"
+  # pylint: enable=invalid-name
+  paddings = [SAME, SAME, SAME, SAME, VALID]
+  for i, f, o, s, d, p in zip(input_sizes, filter_sizes, out_sizes, strides,
+                              dilations, paddings):
+    yield i, f, o, s, d, p
+
+
 def CheckGradConfigsToTest():
   """Iterator for different convolution shapes, strides and paddings.
 
@@ -315,6 +342,118 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
         padding="VALID",
         expected=expected_output)
 
+  # This tests that depthwise_conv2d with dilation produces the same
+  # results on CPU and TPU. It also tests that the NCHW and NHWC data
+  # formats agree.
+  def _VerifyValuesWithDilation(self,
+                                tensor_in_sizes,
+                                filter_in_sizes,
+                                stride,
+                                dilation,
+                                padding,
+                                data_type,
+                                data_format="NHWC"):
+    """Verifies the output values of the convolution function.
+
+    Args:
+      tensor_in_sizes: Input tensor dimensions in [batch, input_rows,
+        input_cols, input_depth].
+      filter_in_sizes: Filter tensor dimensions in [filter_rows, filter_cols,
+        input_depth, depth_multiplier].
+      stride: Stride.
+      dilation: Dilation.
+      padding: Padding type.
+      data_type: The data type to use.
+      data_format: The data_format of the input. "NHWC" or "NCHW".
+    """
+    total_size_1 = 1
+    total_size_2 = 1
+    for s in tensor_in_sizes:
+      total_size_1 *= s
+    for s in filter_in_sizes:
+      total_size_2 *= s
+    # Initializes the input and filter tensors with numbers incrementing from 1.
+    x1 = np.array([f * 1.0 for f in range(1, total_size_1 + 1)],
+                  dtype=data_type).reshape(tensor_in_sizes)
+    x2 = np.array([f * 1.0 for f in range(1, total_size_2 + 1)],
+                  dtype=data_type).reshape(filter_in_sizes)
+    with self.session() as sess:
+      if data_type == np.float32:
+        # TODO(b/64210055): Tolerance for TPU is high.
+        tolerance = 1e-2
+      else:
+        self.assertEqual(data_type, np.float64)
+        tolerance = 1e-8
+
+      t1 = array_ops.placeholder(shape=tensor_in_sizes, dtype=data_type)
+      t2 = array_ops.placeholder(shape=filter_in_sizes, dtype=data_type)
+
+      native_t1 = t1
+      strides = [1, stride, stride, 1]
+      dilations = [dilation, dilation]
+      if data_format == "NCHW":
+        # Transpose from NHWC input to NCHW,
+        # e.g. [4, 5, 5, 48] to [4, 48, 5, 5].
+        native_t1 = array_ops.transpose(t1, [0, 3, 1, 2])
+        strides = [1, 1, stride, stride]
+
+      with self.test_scope():
+        conv_native = nn_impl.depthwise_conv2d(
+            native_t1,
+            t2,
+            strides=strides,
+            rate=dilations,
+            data_format=data_format,
+            padding=padding)
+
+      if data_format == "NCHW":
+        # Transpose back from NCHW to NHWC.
+        conv_native = array_ops.transpose(conv_native, [0, 2, 3, 1])
+
+      with ops.device("CPU"):
+        # CPU only supports the NHWC format.
+        strides = [1, stride, stride, 1]
+        conv_interface = nn_impl.depthwise_conv2d(
+            t1, t2, strides=strides, rate=dilations, padding=padding)
+
+      native_result = sess.run(conv_native, {t1: x1, t2: x2})
+      interface_result = sess.run(conv_interface, {t1: x1, t2: x2})
+
+    print("data_type:", data_type, "max diff = ",
+          np.amax(np.absolute(native_result - interface_result)))
+    self.assertAllClose(
+        np.ravel(native_result), np.ravel(interface_result), rtol=tolerance)
+
+  def testDilationDepthwiseConv2DWith(self):
+    for index, (input_size, filter_size, _, stride, dilation,
+                padding) in enumerate(ConfigsWithDilationsToTest()):
+      print("Testing DilationDepthwiseConv2D,", index, "th config:", input_size,
+            "*", filter_size, "stride:", stride, "dilation:", dilation,
+            "padding:", padding)
+      for data_type in self.float_types:
+        # TODO(phawkins): the reference implementation only supports float32.
+        if data_type == np.float32:
+          self._VerifyValuesWithDilation(input_size, filter_size, stride,
+                                         dilation, padding, data_type)
+
+  def testDilationDepthwiseConv2DWithFormat(self):
+    for index, (input_size, filter_size, _, stride, dilation,
+                padding) in enumerate(ConfigsWithDilationsToTest()):
+      print("Testing DilationDepthwiseConv2DFormat,", index, "th config:",
+            input_size, "*", filter_size, "stride:", stride, "dilation:",
+            dilation, "padding:", padding)
+      for data_type in self.float_types:
+        # TODO(phawkins): the reference implementation only supports float32.
+        if data_type == np.float32:
+          self._VerifyValuesWithDilation(
+              input_size,
+              filter_size,
+              stride,
+              dilation,
+              padding,
+              data_type,
+              data_format="NCHW")
+
   def _CompareBackpropInput(self, input_sizes, filter_sizes, output_sizes,
                             stride, padding):
     x1 = np.random.rand(*filter_sizes).astype(np.float32)
@@ -420,5 +559,139 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
                                 padding,
                                 data_format="NCHW")
 
+  def _CompareBackpropInputWithDilation(self, input_sizes, filter_sizes,
+                                        output_sizes, stride, dilation,
+                                        padding):
+    x1 = np.random.rand(*filter_sizes).astype(np.float32)
+    x2 = np.random.rand(*output_sizes).astype(np.float32)
+
+    def _GetVal(use_xla):
+      with self.session():
+        t1 = array_ops.placeholder(np.float32, shape=filter_sizes)
+        t2 = array_ops.placeholder(np.float32, shape=output_sizes)
+        if use_xla:
+          with self.test_scope():
+            t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)])
+            backprop = nn_ops.depthwise_conv2d_native_backprop_input(
+                t0,
+                t1,
+                t2,
+                strides=[1, stride, stride, 1],
+                dilations=[1, dilation, dilation, 1],
+                padding=padding)
+        else:
+          # TODO(wangtao): figure out gradient with stride > 1.
+          # depthwise_conv2d_native_backprop_input on CPU doesn't support
+          # dilation, so emulate it with space_to_batch/batch_to_space.
+          t3 = array_ops.space_to_batch(
+              t2, block_size=dilation, paddings=[[0, 0], [0, 0]])
+          input_sizes_transform = [
+              input_sizes[0] * dilation * dilation, input_sizes[1] // dilation,
+              input_sizes[2] // dilation, input_sizes[3]
+          ]
+          t0 = constant_op.constant(
+              input_sizes_transform, shape=[len(input_sizes)])
+          backprop_naive = nn_ops.depthwise_conv2d_native_backprop_input(
+              t0, t1, t3, strides=[1, stride, stride, 1], padding=padding)
+          backprop = array_ops.batch_to_space(
+              backprop_naive, [[0, 0], [0, 0]], block_size=dilation)
+
+        ret = backprop.eval({t1: x1, t2: x2})
+        self.assertShapeEqual(ret, backprop)
+        return ret
+
+    gpu_value = _GetVal(use_xla=True)
+    cpu_value = _GetVal(use_xla=False)
+
+    # TODO(b/64210055): Tolerance for TPU is high.
+    self.assertAllClose(cpu_value, gpu_value, rtol=1e-2, atol=1e-3)
+
+  def testDilationDepthwiseConv2DInputGradWithCompare(self):
+    for index, (input_size, filter_size, output_size, stride, dilation,
+                padding) in enumerate(ConfigsWithDilationsToTest()):
+      print("Testing DilationDepthwiseConv2DInputGradWithDilationCompare,",
+            index, "th config:", input_size, "*", filter_size, "stride:",
+            stride, "dilation:", dilation, "padding:", padding)
+      # TODO(wangtao): implement CPU grad computation with stride > 1.
+      if stride == 1:
+        self._CompareBackpropInputWithDilation(input_size, filter_size,
+                                               output_size, stride, dilation,
+                                               padding)
+
+  def _CompareBackpropFilterWithDilation(self,
+                                         input_sizes,
+                                         filter_sizes,
+                                         output_sizes,
+                                         stride,
+                                         dilation,
+                                         padding,
+                                         data_format="NHWC"):
+    x0 = np.random.rand(*input_sizes).astype(np.float32)
+    x2 = np.random.rand(*output_sizes).astype(np.float32)
+
+    def _GetVal(use_xla):
+      with self.session():
+        t0 = array_ops.placeholder(np.float32, shape=input_sizes)
+        t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
+        t2 = array_ops.placeholder(np.float32, shape=output_sizes)
+        native_t0 = t0
+        native_t2 = t2
+        strides = [1, stride, stride, 1]
+        dilations = [1, dilation, dilation, 1]
+
+        if use_xla:
+          if data_format == "NCHW":
+            # Transpose from NHWC input to NCHW,
+            # e.g. [4, 5, 5, 48] to [4, 48, 5, 5].
+            native_t0 = array_ops.transpose(t0, [0, 3, 1, 2])
+            native_t2 = array_ops.transpose(t2, [0, 3, 1, 2])
+            strides = [1, 1, stride, stride]
+            dilations = [1, 1, dilation, dilation]
+          with self.test_scope():
+            backprop = nn_ops.depthwise_conv2d_native_backprop_filter(
+                native_t0,
+                t1,
+                native_t2,
+                strides=strides,
+                padding=padding,
+                dilations=dilations,
+                data_format=data_format)
+        else:
+          # For CPU, the NCHW format is not supported, so we always use NHWC
+          # here.
+          # depthwise_conv2d_native_backprop_filter on CPU doesn't support
+          # dilation, so emulate it with space_to_batch.
+          native_t3 = array_ops.space_to_batch(
+              native_t2, block_size=dilation, paddings=[[0, 0], [0, 0]])
+          native_t0_transform = array_ops.space_to_batch(
+              native_t0, block_size=dilation, paddings=[[0, 0], [0, 0]])
+          backprop = nn_ops.depthwise_conv2d_native_backprop_filter(
+              native_t0_transform,
+              t1,
+              native_t3,
+              strides=strides,
+              padding=padding)
+        ret = backprop.eval({t0: x0, t2: x2})
+        self.assertShapeEqual(ret, backprop)
+        return ret
+
+    gpu_value = _GetVal(use_xla=True)
+    cpu_value = _GetVal(use_xla=False)
+    # TODO(b/64210055): Tolerance for TPU is high.
+    self.assertAllClose(cpu_value, gpu_value, rtol=1e-3, atol=1e-4)
+
+  def testDilationDepthwiseConv2DFilterGradCompare(self):
+    for index, (input_size, filter_size, output_size, stride, dilation,
+                padding) in enumerate(ConfigsWithDilationsToTest()):
+      print("Testing DilationDepthwiseConv2DFilterGradCompare,", index,
+            "th config:", input_size, "*", filter_size, "producing output",
+            output_size, "stride:", stride, "dilation:", dilation, "padding:",
+            padding)
+      if stride == 1:
+        # TODO(wangtao): implement CPU grad computation with stride > 1.
+        self._CompareBackpropFilterWithDilation(input_size, filter_size,
+                                                output_size, stride, dilation,
+                                                padding)
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index ea934b81680..9d5f316ff6f 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -408,11 +408,15 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
                 errors::InvalidArgument(
                     "Current implementation does not yet support "
                     "dilations in the batch and depth dimensions."));
-    // TODO(yangzihao): Add a CPU implementation for dilated convolution.
-    OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
-                errors::InvalidArgument(
-                    "Current libxsmm and customized CPU implementations do "
-                    "not yet support dilation rates larger than 1."));
+    if (std::is_same<Device, CPUDevice>::value ||
+        std::is_same<Device, GPUDevice>::value) {
+      // TODO(yangzihao): Add a CPU implementation for dilated convolution.
+      OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
+                  errors::InvalidArgument(
+                      "Current libxsmm and customized CPU implementations do "
+                      "not yet support dilation rates larger than 1."));
+      dilations_ = {1, 1, 1, 1};
+    }
   }
 
   void Compute(OpKernelContext* context) override {
@@ -434,8 +438,8 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
         context,
         ConvBackpropComputeDimensionsV2(
             "Conv2DCustomBackpropFilter", /*num_spatial_dims=*/2, input.shape(),
-            filter_shape, out_backprop.shape(), /*dilations=*/{1, 1, 1, 1},
-            strides_, padding_, explicit_paddings_, data_format_, &dims));
+            filter_shape, out_backprop.shape(), dilations_, strides_, padding_,
+            explicit_paddings_, data_format_, &dims));
 
     Tensor* filter_backprop;
     OP_REQUIRES_OK(context,
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 9073b323ec7..7d3160444d8 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -628,15 +628,17 @@ def _DepthwiseConv2dNativeGrad(op, grad):
           array_ops.shape(op.inputs[0]),
           op.inputs[1],
           grad,
-          op.get_attr("strides"),
-          op.get_attr("padding"),
+          dilations=op.get_attr("dilations"),
+          strides=op.get_attr("strides"),
+          padding=op.get_attr("padding"),
           data_format=op.get_attr("data_format")),
       nn_ops.depthwise_conv2d_native_backprop_filter(
           op.inputs[0],
           array_ops.shape(op.inputs[1]),
           grad,
-          op.get_attr("strides"),
-          op.get_attr("padding"),
+          dilations=op.get_attr("dilations"),
+          strides=op.get_attr("strides"),
+          padding=op.get_attr("padding"),
           data_format=op.get_attr("data_format"))
   ]
 
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 216c5754606..1435a0c6cbe 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -715,6 +715,22 @@ def zero_fraction(value, name=None):
     return array_ops.identity(zero_fraction_float32, "fraction")
 
 
+# copybara:strip_begin
+# TODO(b/138808492): Remove code inside copybara
+# to make TPU code and CPU code consistent.
+def _enclosing_tpu_context():
+  # pylint: disable=protected-access
+  context = ops.get_default_graph()._get_control_flow_context()
+  # pylint: enable=protected-access
+  while context is not None and not isinstance(
+      context, control_flow_ops.XLAControlFlowContext):
+    context = context.outer_context
+  return context
+
+
+# copybara:strip_end
+
+
 # pylint: disable=redefined-builtin
 @tf_export(v1=["nn.depthwise_conv2d"])
 def depthwise_conv2d(input,
@@ -774,6 +790,25 @@ def depthwise_conv2d(input,
   if rate is None:
     rate = [1, 1]
 
+  # copybara:strip_begin
+  # TODO(b/138808492): Remove code inside copybara
+  # to make TPU code and CPU code consistent.
+  # Use depthwise_conv2d_native directly when executing on TPU.
+  if _enclosing_tpu_context() is not None:
+    if data_format == "NCHW":
+      dilations = [1, 1, rate[0], rate[1]]
+    else:
+      dilations = [1, rate[0], rate[1], 1]
+    return nn_ops.depthwise_conv2d_native(
+        input=input,
+        filter=filter,
+        strides=strides,
+        padding=padding,
+        data_format=data_format,
+        dilations=dilations,
+        name=name)
+  # copybara:strip_end
+
   def op(input_converted, _, padding):
     return nn_ops.depthwise_conv2d_native(
         input=input_converted,
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index f68b038c01f..98a4030641e 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
@@ -919,6 +920,22 @@ convolution_v2.__doc__ = deprecation.rewrite_argument_docstring(
     "filter", "filters")
 
 
+# copybara:strip_begin
+# TODO(b/138808492): Remove code inside copybara
+# to make TPU code and CPU code consistent.
+def _enclosing_tpu_context():
+  # pylint: disable=protected-access
+  run_context = ops.get_default_graph()._get_control_flow_context()
+  # pylint: enable=protected-access
+  while run_context is not None and not isinstance(
+      run_context, control_flow_ops.XLAControlFlowContext):
+    run_context = run_context.outer_context
+  return run_context
+
+
+# copybara:strip_end
+
+
 def convolution_internal(
     input,  # pylint: disable=redefined-builtin
     filters,
@@ -958,8 +975,14 @@ def convolution_internal(
 
   conv_ops = {1: conv1d, 2: gen_nn_ops.conv2d, 3: gen_nn_ops.conv3d}
 
-  if all(i == 1 for i in dilations):
-    # fast path if no dilation as gradient only supported on GPU for dilations
+  # copybara:strip_begin
+  # TODO(b/138808492): Remove code inside copybara
+  # to make TPU code and CPU code consistent.
+  if _enclosing_tpu_context() is not None or all(i == 1 for i in dilations):
+    # Fast path for TPU, or when there is no dilation; dilated gradients are
+    # only supported on GPU.
+    # copybara:strip_end
+    # copybara:insert if all(i == 1 for i in dilations):
     op = conv_ops[n]
     return op(
         input,
@@ -1056,7 +1079,9 @@ class Convolution(object):
     self.filter_shape = filter_shape
     self.data_format = data_format
     self.strides = strides
+    self.padding = padding
     self.name = name
+    self.dilation_rate = dilation_rate
     self.conv_op = _WithSpaceToBatch(
         input_shape,
         dilation_rate=dilation_rate,
@@ -1076,7 +1101,23 @@ class Convolution(object):
         name=self.name)
 
   def __call__(self, inp, filter):  # pylint: disable=redefined-builtin
-    return self.conv_op(inp, filter)
+    # copybara:strip_begin
+    # TODO(b/138808492): Remove code inside copybara
+    # to make TPU code and CPU code consistent.
+    # TPU convolution supports dilations greater than 1.
+    if _enclosing_tpu_context() is not None:
+      return convolution_internal(
+          inp,
+          filter,
+          strides=self.strides,
+          padding=self.padding,
+          data_format=self.data_format,
+          dilations=self.dilation_rate,
+          name=self.name)
+    else:
+      return self.conv_op(inp, filter)
+    # copybara:strip_end
+    # copybara:insert return self.conv_op(inp, filter)
 
 
 @tf_export(v1=["nn.pool"])
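--

Reviewer notes (not part of the patch):

The CPU reference paths in the new tests emulate dilation with the
space_to_batch -> stride-1 convolution -> batch_to_space rewrite that this
change stops emitting on TPU. Below is a minimal standalone sketch of that
identity, written against the plain Conv2D op (whose forward CPU kernel
accepts `dilations`); the shapes are made up for illustration, and with SAME
padding the two forms only line up exactly when the spatial size is divisible
by the dilation rate:

    import numpy as np
    import tensorflow as tf  # TF 1.x API, matching this branch

    dilation = 2
    x = tf.placeholder(tf.float32, [1, 8, 8, 3])  # NHWC; 8 % dilation == 0
    w = tf.placeholder(tf.float32, [3, 3, 3, 4])  # HWIO filter

    # Direct form: a single dilated convolution.
    direct = tf.nn.conv2d(
        x, w, strides=[1, 1, 1, 1], padding="SAME",
        dilations=[1, dilation, dilation, 1])

    # Emulated form: space_to_batch -> stride-1 convolution -> batch_to_space.
    # space_to_batch stacks dilation*dilation interleaved sub-images along the
    # batch dimension, so an undilated filter applied to each sub-image touches
    # pixels that are `dilation` apart in the original image.
    x_s2b = tf.space_to_batch(x, paddings=[[0, 0], [0, 0]], block_size=dilation)
    conv = tf.nn.conv2d(x_s2b, w, strides=[1, 1, 1, 1], padding="SAME")
    emulated = tf.batch_to_space(
        conv, crops=[[0, 0], [0, 0]], block_size=dilation)

    with tf.Session() as sess:
      x_val = np.random.rand(1, 8, 8, 3).astype(np.float32)
      w_val = np.random.rand(3, 3, 3, 4).astype(np.float32)
      a, b = sess.run([direct, emulated], {x: x_val, w: w_val})
      np.testing.assert_allclose(a, b, rtol=1e-5)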
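A related detail in _CompareBackpropInputWithDilation: because the CPU
reference computes the gradient in space-to-batch coordinates, the
input_sizes passed to depthwise_conv2d_native_backprop_input must be the
space-to-batch shape, not the original one. Worked through with one of the
test configs above ([4, 36, 36, 2] at dilation 4):

    input_sizes = [4, 36, 36, 2]
    dilation = 4

    # space_to_batch with block_size=dilation multiplies the batch dimension
    # by dilation**2 and divides each spatial dimension by dilation.
    input_sizes_transform = [
        input_sizes[0] * dilation * dilation,  # 4 * 16 = 64
        input_sizes[1] // dilation,            # 36 // 4 = 9
        input_sizes[2] // dilation,            # 36 // 4 = 9
        input_sizes[3],                        # channels unchanged: 2
    ]
    assert input_sizes_transform == [64, 9, 9, 2]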
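At the Python level, the nn_impl.py change is visible in the graph: outside
an XLA context, nn.depthwise_conv2d with rate > 1 still lowers to the
SpaceToBatchND / DepthwiseConv2dNative / BatchToSpaceND sandwich, while
inside one it now emits a single DepthwiseConv2dNative with the rate
forwarded as `dilations`. A quick probe of the off-TPU graph (again TF 1.x;
the op type names are TensorFlow's standard ones, not introduced here):

    import tensorflow as tf

    g = tf.Graph()
    with g.as_default():
      x = tf.placeholder(tf.float32, [1, 8, 8, 3])
      w = tf.placeholder(tf.float32, [3, 3, 3, 1])
      y = tf.nn.depthwise_conv2d(
          x, w, strides=[1, 1, 1, 1], padding="SAME", rate=[2, 2])

    # Off TPU, the dilation is still lowered via the space-to-batch rewrite.
    print(sorted({op.type for op in g.get_operations()
                  if op.type in ("SpaceToBatchND", "DepthwiseConv2dNative",
                                 "BatchToSpaceND")}))
    # -> ['BatchToSpaceND', 'DepthwiseConv2dNative', 'SpaceToBatchND']
    # Under _enclosing_tpu_context(), only DepthwiseConv2dNative would appear,
    # carrying dilations=[1, 2, 2, 1].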