diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index 18b7a47fc8c..e01abc8133d 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -455,6 +455,58 @@ class Conv2DTest(test.TestCase):
         conv1,
         self.evaluate(conv2).reshape(conv1.shape))
 
+  @test_util.run_in_graph_and_eager_modes
+  def testConvolutionClass2DExpandedBatch(self):
+    tensor_in_sizes_batch = [10, 2, 3, 3]
+    tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 3]
+    filter_in_sizes = [1, 1, 3, 3]
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
+    x2 = x1.reshape(tensor_in_sizes_expanded_batch)
+    convolver1 = nn_ops.Convolution(
+        input_shape=x1.shape,
+        filter_shape=filter_in.shape,
+        strides=[1, 1],
+        padding="VALID")
+    self.assertEqual(convolver1.num_batch_dims, 1)
+    convolver2 = nn_ops.Convolution(
+        input_shape=x2.shape,
+        filter_shape=filter_in.shape,
+        strides=[1, 1],
+        padding="VALID")
+    self.assertEqual(convolver2.num_batch_dims, 2)
+    conv1 = convolver1(x1, filter_in)
+    conv2 = convolver2(x2, filter_in)
+    self.assertEqual(conv1.shape, tensor_in_sizes_batch)
+    self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
+    self.assertAllEqual(
+        conv1,
+        self.evaluate(conv2).reshape(conv1.shape))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testConvolutionWith2SpatialDimensionsAndExpandedBatch(self):
+    tensor_in_sizes_batch = [10, 2, 3, 3]
+    tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 3]
+    filter_in_sizes = [1, 1, 3, 3]
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    x1 = self._CreateNumpyTensor(tensor_in_sizes_batch)
+    x2 = x1.reshape(tensor_in_sizes_expanded_batch)
+    conv1 = nn_ops.convolution(
+        x1,
+        filter_in,
+        strides=[1, 1],
+        padding="VALID")
+    conv2 = nn_ops.convolution(
+        x2,
+        filter_in,
+        strides=[1, 1],
+        padding="VALID")
+    self.assertEqual(conv1.shape, tensor_in_sizes_batch)
+    self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch)
+    self.assertAllEqual(
+        conv1,
+        self.evaluate(conv2).reshape(conv1.shape))
+
   @test_util.run_in_graph_and_eager_modes
   def testConv2D2x2Filter2x1Dilation(self):
     self._VerifyDilatedConvValues(
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 4c00d085f82..24ee94fac48 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -131,9 +131,9 @@ def _non_atrous_convolution(
   """
   with ops.name_scope(name, "non_atrous_convolution", [input, filter]) as scope:
     input = ops.convert_to_tensor(input, name="input")  # pylint: disable=redefined-builtin
-    input_shape = input.get_shape()
+    input_shape = input.shape
     filter = ops.convert_to_tensor(filter, name="filter")  # pylint: disable=redefined-builtin
-    filter_shape = filter.get_shape()
+    filter_shape = filter.shape
     op = _NonAtrousConvolution(
         input_shape,
         filter_shape=filter_shape,
@@ -148,36 +148,51 @@ class _NonAtrousConvolution(object):
   """Helper class for _non_atrous_convolution.
 
   Note that this class assumes that shapes of input and filter passed to
-  __call__ are compatible with input_shape and filter_shape passed to the
+  `__call__` are compatible with `input_shape` and filter_shape passed to the
   constructor.
 
   Arguments:
-    input_shape: static input shape, i.e. input.get_shape().
-    filter_shape: static filter shape, i.e. filter.get_shape().
+    input_shape: static input shape, i.e. input.shape.
+    filter_shape: static filter shape, i.e. filter.shape.
     padding: see _non_atrous_convolution.
     data_format: see _non_atrous_convolution.
     strides: see _non_atrous_convolution.
     name: see _non_atrous_convolution.
+    num_batch_dims: (Optional.)  The number of batch dimensions in the input;
+     if not provided, the default of `1` is used.
   """
 
   def __init__(
       self,
       input_shape,
-      filter_shape,  # pylint: disable=redefined-builtin
+      filter_shape,
       padding,
       data_format=None,
       strides=None,
-      name=None):
-    filter_shape = filter_shape.with_rank(input_shape.ndims)
+      name=None,
+      num_batch_dims=1):
+    # filter shape is always rank num_spatial_dims + 2
+    # and num_spatial_dims == input_shape.ndims - num_batch_dims - 1
+    if input_shape.ndims is not None:
+      filter_shape = filter_shape.with_rank(
+          input_shape.ndims - num_batch_dims + 1)
     self.padding = padding
     self.name = name
-    input_shape = input_shape.with_rank(filter_shape.ndims)
+    # input shape is == num_spatial_dims + num_batch_dims + 1
+    # and filter_shape is always rank num_spatial_dims + 2
+    if filter_shape.ndims is not None:
+      input_shape = input_shape.with_rank(
+          filter_shape.ndims + num_batch_dims - 1)
     if input_shape.ndims is None:
-      raise ValueError("Rank of convolution must be known")
-    if input_shape.ndims < 3 or input_shape.ndims > 5:
       raise ValueError(
-          "`input` and `filter` must have rank at least 3 and at most 5")
-    conv_dims = input_shape.ndims - 2
+          "Rank of convolution must be known, but saw input_shape.ndims == {}"
+          .format(input_shape.ndims))
+    if input_shape.ndims < 3 or input_shape.ndims - num_batch_dims + 1 > 5:
+      raise ValueError(
+          "`input_shape.ndims - num_batch_dims + 1` must be at least 3 and at "
+          "most 5 but saw `input_shape.ndims == {}` and `num_batch_dims == {}`"
+          .format(input_shape.ndims, num_batch_dims))
+    conv_dims = input_shape.ndims - num_batch_dims - 1
     if strides is None:
       strides = [1] * conv_dims
     elif len(strides) != conv_dims:
@@ -520,7 +535,7 @@ def with_space_to_batch(
 
   """
   input = ops.convert_to_tensor(input, name="input")  # pylint: disable=redefined-builtin
-  input_shape = input.get_shape()
+  input_shape = input.shape
 
   def build_op(num_spatial_dims, padding):
     return lambda inp, _: op(inp, num_spatial_dims, padding)
@@ -540,18 +555,19 @@ class _WithSpaceToBatch(object):
   """Helper class for with_space_to_batch.
 
   Note that this class assumes that shapes of input and filter passed to
-  __call__ are compatible with input_shape and filter_shape passed to the
-  constructor.
+  `__call__` are compatible with `input_shape`, `filter_shape`, and
+  `spatial_dims` passed to the constructor.
 
   Arguments
-    input_shape: static shape of input. i.e. input.get_shape().
-    dilation_rate: see with_space_to_batch
-    padding: see with_space_to_batch
+    input_shape: static shape of input. i.e. input.shape.
+    dilation_rate: see `with_space_to_batch`.
+    padding: see `with_space_to_batch`.
     build_op: Function that maps (num_spatial_dims, paddings) -> (function that
       maps (input, filter) -> output).
-    filter_shape: see with_space_to_batch
-    spatial_dims: see with_space_to_batch
-    data_format: see with_space_to_batch
+    filter_shape: see `with_space_to_batch`.
+    spatial_dims: `see with_space_to_batch`.
+    data_format: see `with_space_to_batch`.
+    num_batch_dims: (Optional).  Number of batch dims in `input_shape`.
   """
 
   def __init__(self,
@@ -561,24 +577,25 @@ class _WithSpaceToBatch(object):
                build_op,
                filter_shape=None,
                spatial_dims=None,
-               data_format=None):
+               data_format=None,
+               num_batch_dims=1):
     """Helper class for _with_space_to_batch."""
     dilation_rate = ops.convert_to_tensor(
         dilation_rate, dtypes.int32, name="dilation_rate")
-    try:
-      rate_shape = dilation_rate.get_shape().with_rank(1)
-    except ValueError:
-      raise ValueError("rate must be rank 1")
+    if dilation_rate.shape.ndims not in (None, 1):
+      raise ValueError(
+          "rate must be rank 1 but saw {}".format(dilation_rate.shape.ndims))
 
-    if not dilation_rate.get_shape().is_fully_defined():
-      raise ValueError("rate must have known shape")
+    if not dilation_rate.shape.is_fully_defined():
+      raise ValueError("rate must have known shape, but saw {}"
+                       .format(dilation_rate.shape))
 
-    num_spatial_dims = rate_shape.dims[0].value
+    num_spatial_dims = dilation_rate.shape.dims[0].value
 
     if data_format is not None and data_format.startswith("NC"):
-      starting_spatial_dim = 2
+      starting_spatial_dim = num_batch_dims + 1
     else:
-      starting_spatial_dim = 1
+      starting_spatial_dim = num_batch_dims
 
     if spatial_dims is None:
       spatial_dims = range(starting_spatial_dim,
@@ -588,7 +605,7 @@ class _WithSpaceToBatch(object):
     if spatial_dims != orig_spatial_dims or any(x < 1 for x in spatial_dims):
       raise ValueError(
           "spatial_dims must be a monotonically increasing sequence of "
-          "positive integers")
+          "positive integers, but saw: {}".format(orig_spatial_dims))
 
     if data_format is not None and data_format.startswith("NC"):
       expected_input_rank = spatial_dims[-1]
@@ -599,14 +616,16 @@ class _WithSpaceToBatch(object):
       input_shape.with_rank_at_least(expected_input_rank)
     except ValueError:
       raise ValueError(
-          "input tensor must have rank %d at least" % (expected_input_rank))
+          "input tensor must have rank at least {}, but saw rank {}"
+          .format(expected_input_rank, input_shape.ndims))
 
     const_rate = tensor_util.constant_value(dilation_rate)
     rate_or_const_rate = dilation_rate
     if const_rate is not None:
       rate_or_const_rate = const_rate
       if np.any(const_rate < 1):
-        raise ValueError("dilation_rate must be positive")
+        raise ValueError("dilation_rate must be positive, but saw: {}"
+                         .format(const_rate))
       if np.all(const_rate == 1):
         self.call = build_op(num_spatial_dims, padding)
         return
@@ -672,6 +691,7 @@ class _WithSpaceToBatch(object):
       filter_shape = array_ops.shape(filter)
       base_paddings = _with_space_to_batch_base_paddings(
           filter_shape, self.num_spatial_dims, self.rate_or_const_rate)
+
     paddings, crops = array_ops.required_space_to_batch_paddings(
         input_shape=input_spatial_shape,
         base_paddings=base_paddings,
@@ -994,31 +1014,84 @@ def convolution_internal(
     data_format=None,
     dilations=None,
     name=None,
-    call_from_convolution=True):
-  """Internal function which performs rank agnostic convolution."""
-  if isinstance(input.shape, tensor_shape.TensorShape) and \
-        input.shape.rank is not None:
-    n = len(input.shape) - 2
-  elif not isinstance(input.shape, tensor_shape.TensorShape) and \
-        input.shape is not None:
-    n = len(input.shape) - 2
-  elif isinstance(filters.shape, tensor_shape.TensorShape) and \
-        filters.shape.rank is not None:
+    call_from_convolution=True,
+    num_spatial_dims=None):
+  """Internal function which performs rank agnostic convolution.
+
+  Args:
+    input: See `convolution`.
+    filters: See `convolution`.
+    strides: See `convolution`.
+    padding: See `convolution`.
+    data_format: See `convolution`.
+    dilations: See `convolution`.
+    name: See `convolution`.
+    call_from_convolution: See `convolution`.
+    num_spatial_dims: (Optional.).  It is a integer describing the
+      rank of the spatial dimensions.  For `1-D`, `2-D` and `3-D` convolutions,
+      the value of `num_spatial_dims` is `1`, `2`, and `3`, respectively.
+      This argument is only required to disambiguate the rank of `batch_shape`
+      when `filter_shape.ndims is None` and `len(batch_shape) > 1`.  For
+      backwards compatibility, if `num_spatial_dims is None` and
+     `filter_shape.ndims is None`, then `len(batch_shape)` is assumed to be
+     `1` (i.e., the input is expected to be
+     `[batch_size, num_channels] + input_spatial_shape`
+     or `[batch_size] + input_spatial_shape + [num_channels]`.
+
+  Returns:
+    A tensor of shape and dtype matching that of `input`.
+
+  Raises:
+    ValueError: If input and filter both have unknown shapes, or if
+      `num_spatial_dims` is provided and incompatible with the value
+      estimated from `filters.shape`.
+  """
+  n = None
+  if getattr(filters, 'shape', None) is None:
+    with ops.name_scope(name, 'convolution_internal', [filters, input]):
+      filters = ops.convert_to_tensor(filters, name='filters')
+  if (isinstance(filters.shape, tensor_shape.TensorShape)
+      and filters.shape.rank is not None):
     n = len(filters.shape) - 2
-  elif not isinstance(filters.shape, tensor_shape.TensorShape) and \
-        filters.shape is not None:
+  elif (not isinstance(filters.shape, tensor_shape.TensorShape)
+        and filters.shape is not None):
     n = len(filters.shape) - 2
+
+  if (isinstance(input.shape, tensor_shape.TensorShape)
+      and input.shape.rank is not None):
+    if n is None:
+      n = (num_spatial_dims if num_spatial_dims is not None
+           else len(input.shape) - 2)
+    num_batch_dims = len(input.shape) - n - 1
+  elif (not isinstance(input.shape, tensor_shape.TensorShape)
+        and input.shape is not None):
+    if n is None:
+      n = (num_spatial_dims if num_spatial_dims is not None
+           else len(input.shape) - 2)
+    num_batch_dims = len(input.shape) - n - 1
   else:
+    num_batch_dims = 1  # Default behavior if it cannot be estimated.
+
+  if n is None:
     raise ValueError("rank of input or filter must be known")
 
+  if num_spatial_dims is not None and n != num_spatial_dims:
+    raise ValueError(
+        "inconsistent estimate of spatial dims ({}) vs. actual passed "
+        "num_spatial_dims ({}).  n was estimated as len(filters.shape) - 2, "
+        "but filters shape is: {}".format(n, num_spatial_dims, filters.shape))
+
   if not 1 <= n <= 3:
     raise ValueError(
-        "Input tensor must be of rank 3, 4 or 5 but was {}.".format(n + 2))
+        "num_spatial_dims (input.shape.ndims - num_batch_dims - 1) must be one "
+        "of 1, 2 or 3 but saw {}.  num_batch_dims: {}."
+        .format(n, num_batch_dims))
 
   if data_format is None:
-    channel_index = n + 1
+    channel_index = num_batch_dims + n
   else:
-    channel_index = 1 if data_format.startswith("NC") else n + 1
+    channel_index = (
+        num_batch_dims if data_format.startswith("NC") else n + num_batch_dims)
 
   strides = _get_sequence(strides, n, channel_index, "strides")
   dilations = _get_sequence(dilations, n, channel_index, "dilations")
@@ -1031,7 +1104,7 @@ def convolution_internal(
     scope = "convolution"
 
   with ops.name_scope(name, scope, [input, filters]) as name:
-    conv_ops = {1: conv1d, 2: gen_nn_ops.conv2d, 3: gen_nn_ops.conv3d}
+    conv_ops = {1: conv1d, 2: _conv2d_expanded_batch, 3: gen_nn_ops.conv3d}
 
     if device_context.enclosing_tpu_context() is not None or all(
         i == 1 for i in dilations):
@@ -1061,7 +1134,8 @@ def convolution_internal(
           strides=strides,
           dilation_rate=dilations,
           name=name,
-          data_format=data_format)
+          data_format=data_format,
+          num_spatial_dims=n)
       return op(input, filters)
 
 
@@ -1069,17 +1143,34 @@ class Convolution(object):
   """Helper class for convolution.
 
   Note that this class assumes that shapes of input and filter passed to
-  __call__ are compatible with input_shape and filter_shape passed to the
-  constructor.
+  `__call__` are compatible with `input_shape`, `filter_shape`, and
+  `num_spatial_dims` passed to the constructor.
 
   Arguments
-    input_shape: static shape of input. i.e. input.get_shape().
-    filter_shape: static shape of the filter. i.e. filter.get_shape().
-    padding:  see convolution.
+    input_shape: static shape of input. i.e. input.shape.  Its length is
+      `batch_shape + input_spatial_shape + [num_channels]` if `data_format`
+      does not start with `NC`, or
+      `batch_shape + [num_channels] + input_spatial_shape` if `data_format`
+      starts with `NC`.
+    filter_shape: static shape of the filter. i.e. filter.shape.
+    padding: The padding algorithm, must be "SAME" or "VALID".
     strides: see convolution.
     dilation_rate: see convolution.
     name: see convolution.
-    data_format: see convolution.
+    data_format: A string or `None`.  Specifies whether the channel dimension of
+      the `input` and output is the last dimension (if `data_format` is `None`
+      or does not start with `NC`), or the first post-batch dimension (i.e. if
+      `data_format` starts with `NC`).
+    num_spatial_dims: (Usually optional.)  Python integer, the rank of the
+      spatial and channel dimensions.  For `1-D`, `2-D` and `3-D` convolutions,
+      the value of `num_spatial_dims` is `1`, `2`, and `3`, respectively.
+      This argument is only required to disambiguate the rank of `batch_shape`
+      when `filter_shape.ndims is None` and `len(batch_shape) > 1`.  For
+      backwards compatibility, if `num_spatial_dims is None` and
+      `filter_shape.ndims is None`, then `len(batch_shape)` is assumed to be
+      `1` (i.e., the input is expected to be
+      `[batch_size, num_channels] + input_spatial_shape`
+      or `[batch_size] + input_spatial_shape + [num_channels]`.
   """
 
   def __init__(self,
@@ -1089,40 +1180,72 @@ class Convolution(object):
                strides=None,
                dilation_rate=None,
                name=None,
-               data_format=None):
+               data_format=None,
+               num_spatial_dims=None):
     """Helper function for convolution."""
-    num_total_dims = filter_shape.ndims
-    if num_total_dims is None:
-      num_total_dims = input_shape.ndims
-    if num_total_dims is None:
-      raise ValueError("rank of input or filter must be known")
+    num_batch_dims = None
+    filter_shape = tensor_shape.as_shape(filter_shape)
+    input_shape = tensor_shape.as_shape(input_shape)
 
-    num_spatial_dims = num_total_dims - 2
+    if filter_shape.ndims is not None:
+      if (num_spatial_dims is not None and
+          filter_shape.ndims != num_spatial_dims + 2):
+        raise ValueError(
+            "Expected filter_shape.ndims == num_spatial_dims + 2, "
+            "but saw filter_shape.ndims == {} and num_spatial_dims == {}"
+            .format(filter_shape.ndims, num_spatial_dims))
+      else:
+        num_spatial_dims = filter_shape.ndims - 2
 
-    try:
-      input_shape.with_rank(num_spatial_dims + 2)
-    except ValueError:
+    if input_shape.ndims is not None and num_spatial_dims is not None:
+      num_batch_dims = input_shape.ndims - num_spatial_dims - 1
+
+    if num_spatial_dims is None:
+      num_spatial_dims = input_shape.ndims - 2
+    else:
+      if input_shape.ndims is not None:
+        if input_shape.ndims < num_spatial_dims + 2:
+          raise ValueError(
+              "Expected input_shape.ndims >= num_spatial_dims + 2, but saw "
+              "input_shape.ndims == {} and num_spatial_dims == {}"
+              .format(input_shape.ndims, num_spatial_dims))
+        else:
+          if num_batch_dims is None:
+            num_batch_dims = input_shape.ndims - num_spatial_dims - 1
+
+    if num_spatial_dims is None:
       raise ValueError(
-          "input tensor must have rank %d" % (num_spatial_dims + 2))
+          "Cannot estimate num_spatial_dims since input_shape.ndims is None, "
+          "filter_shape.ndims is None, and argument num_spatial_dims is also "
+          "None.")
 
-    try:
-      filter_shape.with_rank(num_spatial_dims + 2)
-    except ValueError:
+    if num_batch_dims is None:
+      num_batch_dims = 1
+
+    if num_batch_dims < 1:
       raise ValueError(
-          "filter tensor must have rank %d" % (num_spatial_dims + 2))
+          "num_batch_dims should be >= 1, but saw {}.  num_batch_dims was "
+          "estimated as `input_shape.ndims - num_spatial_dims - 1` and "
+          "num_spatial_dims was either provided or estimated as "
+          "`filter_shape.ndims - 2`.  input_shape.ndims: {}, "
+          "num_spatial_dims: {}, filter_shape.ndims: {}"
+          .format(num_batch_dims, input_shape.ndims, num_spatial_dims,
+                  filter_shape.ndims))
 
     if data_format is None or not data_format.startswith("NC"):
       input_channels_dim = tensor_shape.dimension_at_index(
-          input_shape, num_spatial_dims + 1)
-      spatial_dims = range(1, num_spatial_dims + 1)
+          input_shape, num_spatial_dims + num_batch_dims)
+      spatial_dims = range(num_batch_dims, num_spatial_dims + num_batch_dims)
     else:
-      input_channels_dim = tensor_shape.dimension_at_index(input_shape, 1)
-      spatial_dims = range(2, num_spatial_dims + 2)
+      input_channels_dim = tensor_shape.dimension_at_index(
+          input_shape, num_batch_dims)
+      spatial_dims = range(
+          num_batch_dims + 1, num_spatial_dims + num_batch_dims + 1)
 
     if not input_channels_dim.is_compatible_with(
         filter_shape[num_spatial_dims]):
       raise ValueError(
-          "number of input channels does not match corresponding dimension of "
+          "Number of input channels does not match corresponding dimension of "
           "filter, {} != {}".format(input_channels_dim,
                                     filter_shape[num_spatial_dims]))
 
@@ -1136,6 +1259,8 @@ class Convolution(object):
     self.padding = padding
     self.name = name
     self.dilation_rate = dilation_rate
+    self.num_batch_dims = num_batch_dims
+    self.num_spatial_dims = num_spatial_dims
     self.conv_op = _WithSpaceToBatch(
         input_shape,
         dilation_rate=dilation_rate,
@@ -1143,7 +1268,8 @@ class Convolution(object):
         build_op=self._build_op,
         filter_shape=filter_shape,
         spatial_dims=spatial_dims,
-        data_format=data_format)
+        data_format=data_format,
+        num_batch_dims=num_batch_dims)
 
   def _build_op(self, _, padding):
     return _NonAtrousConvolution(
@@ -1152,7 +1278,8 @@ class Convolution(object):
         padding=padding,
         data_format=self.data_format,
         strides=self.strides,
-        name=self.name)
+        name=self.name,
+        num_batch_dims=self.num_batch_dims)
 
   def __call__(self, inp, filter):  # pylint: disable=redefined-builtin
     # TPU convolution supports dilations greater than 1.
@@ -1165,7 +1292,8 @@ class Convolution(object):
           data_format=self.data_format,
           dilations=self.dilation_rate,
           name=self.name,
-          call_from_convolution=False)
+          call_from_convolution=False,
+          num_spatial_dims=self.num_spatial_dims)
     else:
       return self.conv_op(inp, filter)
 
@@ -2392,6 +2520,42 @@ def conv2d_transpose_v2(
         name=name)
 
 
+def _conv2d_expanded_batch(
+    input,  # pylint: disable=redefined-builtin
+    filters,
+    strides,
+    padding,
+    data_format,
+    dilations,
+    name):
+  """Helper function for `convolution_internal`; handles expanded batches."""
+  # Try really hard to avoid modifying the legacy name scopes - return early.
+  shape = getattr(input, "shape", None)
+  if shape is not None:
+    ndims = getattr(shape, "ndims", -1)
+    if ndims == -1: ndims = len(shape)
+  if ndims in (4, 3, 2, 1, 0, None):
+    return gen_nn_ops.conv2d(
+        input,
+        filter=filters,
+        strides=strides,
+        padding=padding,
+        data_format=data_format,
+        dilations=dilations,
+        name=name)
+  return _squeeze_batch_dims(
+      input,
+      functools.partial(
+          gen_nn_ops.conv2d,
+          filter=filters,
+          strides=strides,
+          padding=padding,
+          data_format=data_format,
+          dilations=dilations),
+      inner_rank=3,
+      name=name)
+
+
 @tf_export("nn.atrous_conv2d_transpose")
 @dispatch.add_dispatch_support
 def atrous_conv2d_transpose(value,