diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
index 9b13756e62f..80855da2e92 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
@@ -328,17 +328,11 @@ class FullyConnectedDiagonalFB(test.TestCase):
       multiply_result: Result of FisherBlock.multiply(params)
       multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
     """
-
-    def _as_tensors(tensor_or_tuple):
-      if isinstance(tensor_or_tuple, (tuple, list)):
-        return tuple(ops.convert_to_tensor(t) for t in tensor_or_tuple)
-      return ops.convert_to_tensor(tensor_or_tuple)
-
     with ops.Graph().as_default(), self.test_session() as sess:
-      inputs = [_as_tensors(i) for i in inputs]
-      outputs = [_as_tensors(o) for o in outputs]
-      output_grads = [_as_tensors(og) for og in output_grads]
-      params = _as_tensors(params)
+      inputs = as_tensors(inputs)
+      outputs = as_tensors(outputs)
+      output_grads = as_tensors(output_grads)
+      params = as_tensors(params)
 
       block = fb.FullyConnectedDiagonalFB(
           lc.LayerCollection(), has_bias=isinstance(params, (tuple, list)))
@@ -464,6 +458,188 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
 
     self.assertAllClose(output_flat, explicit)
 
+
+class ConvDiagonalFBTest(test.TestCase):
+
+  def setUp(self):
+    super(ConvDiagonalFBTest, self).setUp()
+
+    self.batch_size = 2
+    self.height = 8
+    self.width = 4
+    self.input_channels = 6
+    self.output_channels = 3
+    self.kernel_size = 1
+
+    self.inputs = np.random.randn(self.batch_size, self.height, self.width,
+                                  self.input_channels).astype(np.float32)
+    self.outputs = np.zeros(
+        [self.batch_size, self.height, self.width,
+         self.output_channels]).astype(np.float32)
+    self.output_grads = np.random.randn(
+        self.batch_size, self.height, self.width, self.output_channels).astype(
+            np.float32)
+    self.w = np.random.randn(self.kernel_size, self.kernel_size,
+                             self.input_channels, self.output_channels).astype(
+                                 np.float32)
+    self.b = np.random.randn(self.output_channels).astype(np.float32)
+
+  def fisherApprox(self, has_bias=False):
+    """Fisher approximation using default inputs."""
+    if has_bias:
+      inputs = np.concatenate(
+          [self.inputs,
+           np.ones([self.batch_size, self.height, self.width, 1])],
+          axis=-1)
+    else:
+      inputs = self.inputs
+    return self.buildDiagonalFisherApproximation(inputs, self.output_grads,
+                                                 self.kernel_size)
+
+  def buildDiagonalFisherApproximation(self, inputs, output_grads, kernel_size):
+    r"""Builds explicit diagonal Fisher approximation.
+
+    Fisher's diagonal is (d loss / d w)'s elements squared for
+
+      d/dw = E[\sum_{loc} outer(input_{loc}, output_grad_{loc})]
+
+    where the expectation is taken over examples and the sum over (x, y)
+    locations upon which the convolution is applied.
+
+    Args:
+      inputs: np.array of shape [batch_size, height, width, input_channels].
+      output_grads: np.array of shape [batch_size, height, width,
+        output_channels].
+      kernel_size: int. height and width of kernel.
+
+    Returns:
+      Diagonal np.array of shape [num_params, num_params] for num_params =
+      kernel_size^2 * input_channels * output_channels.
+ """ + batch_size, height, width, input_channels = inputs.shape + assert output_grads.shape[0] == batch_size + assert output_grads.shape[1] == height + assert output_grads.shape[2] == width + output_channels = output_grads.shape[3] + + # If kernel_size == 1, then we don't need to worry about capturing context + # around the pixel upon which a convolution is applied. This makes testing + # easier. + assert kernel_size == 1, "kernel_size != 1 isn't supported." + num_locations = height * width + inputs = np.reshape(inputs, [batch_size, num_locations, input_channels]) + output_grads = np.reshape(output_grads, + [batch_size, num_locations, output_channels]) + + fisher_diag = np.zeros((input_channels, output_channels)) + for i in range(batch_size): + # Each example's approximation is a square(sum-of-outer-products). + example_fisher_diag = np.zeros((input_channels, output_channels)) + for j in range(num_locations): + example_fisher_diag += np.outer(inputs[i, j], output_grads[i, j]) + fisher_diag += np.square(example_fisher_diag) + + # Normalize by batch_size (not num_locations). + return np.diag(fisher_diag.flatten()) / batch_size + + def testMultiply(self): + result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], + [self.output_grads]) + + # Construct Fisher-vector product. + expected_result = self.fisherApprox().dot(self.w.flatten()) + expected_result = expected_result.reshape([ + self.kernel_size, self.kernel_size, self.input_channels, + self.output_channels + ]) + + self.assertAllClose(expected_result, result) + + def testMultiplyInverse(self): + _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], + [self.output_grads]) + + # Construct inverse Fisher-vector product. + expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten()) + expected_result = expected_result.reshape([ + self.kernel_size, self.kernel_size, self.input_channels, + self.output_channels + ]) + + self.assertAllClose(expected_result, result, atol=1e-3) + + def testRegisterAdditionalMinibatch(self): + """Ensure 1 big minibatch and 2 small minibatches are equivalent.""" + multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( + self.w, [self.inputs], [self.outputs], [self.output_grads]) + multiply_result_small, multiply_inverse_result_small = ( + self.runFisherBlockOps(self.w, + np.split(self.inputs, 2), + np.split(self.outputs, 2), + np.split(self.output_grads, 2))) + + self.assertAllClose(multiply_result_big, multiply_result_small) + self.assertAllClose(multiply_inverse_result_big, + multiply_inverse_result_small) + + def testMultiplyHasBias(self): + result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs], + [self.outputs], [self.output_grads]) + # Clone 'b' along 'input_channels' dimension. + b_filter = np.tile( + np.reshape(self.b, [1, 1, 1, self.output_channels]), + [self.kernel_size, self.kernel_size, 1, 1]) + params = np.concatenate([self.w, b_filter], axis=2) + expected_result = self.fisherApprox(True).dot(params.flatten()) + + # Extract 'b' from concatenated parameters. 
+    expected_result = expected_result.reshape([
+        self.kernel_size, self.kernel_size, self.input_channels + 1,
+        self.output_channels
+    ])
+    expected_result = (expected_result[:, :, 0:-1, :], np.reshape(
+        expected_result[:, :, -1, :], [self.output_channels]))
+
+    self.assertEqual(len(result), 2)
+    self.assertAllClose(expected_result[0], result[0])
+    self.assertAllClose(expected_result[1], result[1])
+
+  def runFisherBlockOps(self, params, inputs, outputs, output_grads):
+    """Run Ops guaranteed by FisherBlock interface.
+
+    Args:
+      params: Tensor or 2-tuple of Tensors. Represents weights or weights and
+        bias of this layer.
+      inputs: list of Tensors of shape [batch_size, height, width,
+        input_channels]. Inputs to layer.
+      outputs: list of Tensors of shape [batch_size, height, width,
+        output_channels]. Preactivations produced by layer.
+      output_grads: list of Tensors of shape [batch_size, height, width,
+        output_channels]. Gradient of loss with respect to 'outputs'.
+
+    Returns:
+      multiply_result: Result of FisherBlock.multiply(params)
+      multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
+    """
+    with ops.Graph().as_default(), self.test_session() as sess:
+      inputs = as_tensors(inputs)
+      outputs = as_tensors(outputs)
+      output_grads = as_tensors(output_grads)
+      params = as_tensors(params)
+
+      block = fb.ConvDiagonalFB(
+          lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME')
+      for (i, o) in zip(inputs, outputs):
+        block.register_additional_minibatch(i, o)
+
+      block.instantiate_factors((output_grads,), damping=0.0)
+
+      sess.run(tf_variables.global_variables_initializer())
+      sess.run(block._factor.make_covariance_update_op(0.0))
+      multiply_result = sess.run(block.multiply(params))
+      multiply_inverse_result = sess.run(block.multiply_inverse(params))
+
+    return multiply_result, multiply_inverse_result
+
+
 class ConvKFCBasicFBTest(test.TestCase):
 
   def _testConvKFCBasicFBInitParams(self, params):
@@ -583,5 +759,11 @@ class ConvKFCBasicFBTest(test.TestCase):
     self.assertAllClose(output_flat, explicit)
 
 
+def as_tensors(tensor_or_tuple):
+  """Converts a potentially nested tuple of np.array to Tensors."""
+  if isinstance(tensor_or_tuple, (tuple, list)):
+    return tuple(as_tensors(t) for t in tensor_or_tuple)
+  return ops.convert_to_tensor(tensor_or_tuple)
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
index 53d40da586c..b444e871701 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
@@ -89,6 +89,10 @@ class LayerCollectionTest(test.TestCase):
       lc.register_conv2d(
           array_ops.constant(4), [1, 1, 1, 1], 'SAME',
          array_ops.ones((1, 1, 1, 1)), array_ops.constant(3))
+      lc.register_conv2d(
+          array_ops.constant(4), [1, 1, 1, 1], 'SAME',
+          array_ops.ones((1, 1, 1, 1)), array_ops.constant(3),
+          approx=layer_collection.APPROX_DIAGONAL_NAME)
       lc.register_generic(
           array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME)
       lc.register_generic(
@@ -96,7 +100,7 @@ class LayerCollectionTest(test.TestCase):
           16,
           approx=layer_collection.APPROX_DIAGONAL_NAME)
 
-      self.assertEqual(5, len(lc.get_blocks()))
+      self.assertEqual(6, len(lc.get_blocks()))
 
   def testRegisterBlocksMultipleRegistrations(self):
     with ops.Graph().as_default():
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 6cca2272d7d..5e822b5fe32 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -227,7 +227,7 @@ class FullyConnectedDiagonalFB(FisherBlock):
   'w'. For an example 'x' that produces layer inputs 'a' and output
   preactivations 's',
 
-    v(x, y, w) = vec( x (d loss / d s)^T )
+    v(x, y, w) = vec( a (d loss / d s)^T )
 
   This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
   to the layer's parameters 'w'.
@@ -309,13 +309,29 @@ class FullyConnectedDiagonalFB(FisherBlock):
 class ConvDiagonalFB(FisherBlock):
   """FisherBlock for convolutional layers using a diagonal approx.
 
-  Unlike NaiveDiagonalFB this uses the low-variance "sum of squares" estimator.
+  Estimates the Fisher Information matrix's diagonal entries for a convolutional
+  layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares"
+  estimator.
+
+  Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
+  into it. We are interested in Fisher(params)[i, i]. This is,
+
+    Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
+                         = E[ v(x, y, params)[i] ^ 2 ]
+
+  Consider a convolutional layer in this model with (unshared) filter matrix
+  'w'. For an example image 'x' that produces layer inputs 'a' and output
+  preactivations 's',
+
+    v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )
+
+  where 'loc' is a single (x, y) location in an image.
+
+  This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
+  to the layer's parameters 'w'.
   """
 
-  # TODO(jamesmartens): add units tests for this class
-
-  def __init__(self, layer_collection, params, inputs, outputs, strides,
-               padding):
+  def __init__(self, layer_collection, params, strides, padding):
     """Creates a ConvDiagonalFB block.
 
     Args:
@@ -325,37 +341,39 @@ class ConvDiagonalFB(FisherBlock):
         kernel alone, a Tensor of shape [kernel_height, kernel_width,
         in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
         containing the previous and a Tensor of shape [out_channels].
-      inputs: A Tensor of shape [batch_size, height, width, in_channels].
-        Input activations to this layer.
-      outputs: A Tensor of shape [batch_size, height, width, out_channels].
-        Output pre-activations from this layer.
       strides: The stride size in this layer (1-D Tensor of length 4).
-      padding: The padding in this layer (1-D of Tensor length 4).
+      padding: The padding in this layer (e.g. "SAME").
     """
-    self._inputs = inputs
-    self._outputs = outputs
-    self._strides = strides
+    self._inputs = []
+    self._outputs = []
+    self._strides = tuple(strides) if isinstance(strides, list) else strides
     self._padding = padding
     self._has_bias = isinstance(params, (tuple, list))
 
     fltr = params[0] if self._has_bias else params
     self._filter_shape = tuple(fltr.shape.as_list())
 
-    input_shape = tuple(inputs.shape.as_list())
-    self._num_locations = (
-        input_shape[1] * input_shape[2] // (strides[1] * strides[2]))
-
     super(ConvDiagonalFB, self).__init__(layer_collection)
 
   def instantiate_factors(self, grads_list, damping):
+    # Concatenate inputs, grads_list into single Tensors.
+    inputs = _concat_along_batch_dim(self._inputs)
+    grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
+
+    # Infer number of locations upon which convolution is applied.
+    inputs_shape = tuple(inputs.shape.as_list())
+    self._num_locations = (
+        inputs_shape[1] * inputs_shape[2] //
+        (self._strides[1] * self._strides[2]))
+
     if NORMALIZE_DAMPING_POWER:
       damping /= self._num_locations**NORMALIZE_DAMPING_POWER
     self._damping = damping
 
     self._factor = self._layer_collection.make_or_get_factor(
         fisher_factors.ConvDiagonalFactor,
-        (self._inputs, grads_list, self._filter_shape, self._strides,
-         self._padding, self._has_bias))
+        (inputs, grads_list, self._filter_shape, self._strides, self._padding,
+         self._has_bias))
 
   def multiply_inverse(self, vector):
     reshaped_vect = utils.layer_params_to_mat2d(vector)
@@ -370,6 +388,18 @@ class ConvDiagonalFB(FisherBlock):
   def tensors_to_compute_grads(self):
     return self._outputs
 
+  def register_additional_minibatch(self, inputs, outputs):
+    """Registers an additional minibatch to the FisherBlock.
+
+    Args:
+      inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to
+        the convolution.
+      outputs: Tensor of shape [batch_size, height, width, output_size]. Layer
+        preactivations.
+    """
+    self._inputs.append(inputs)
+    self._outputs.append(outputs)
+
 
 class KroneckerProductFB(FisherBlock):
   """A base class for FisherBlocks with separate input and output factors.
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index beb8ef136e3..10ef5543516 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -273,9 +273,9 @@ class LayerCollection(object):
           fb.ConvKFCBasicFB(self, params, inputs, outputs, strides, padding))
     elif approx == APPROX_DIAGONAL_NAME:
-      self.register_block(params,
-                          fb.ConvDiagonalFB(self, params, inputs, outputs,
-                                            strides, padding))
+      block = fb.ConvDiagonalFB(self, params, strides, padding)
+      block.register_additional_minibatch(inputs, outputs)
+      self.register_block(params, block)
 
   def register_generic(self, params, batch_size, approx=APPROX_DIAGONAL_NAME):
     params = params if isinstance(params, (tuple, list)) else (params,)
 
@@ -379,6 +379,27 @@ class LayerCollection(object):
     self._loss_dict[name] = loss
 
   def make_or_get_factor(self, cls, args):
+    """Insert 'cls(args)' into 'self.fisher_factors' if not already present.
+
+    Wraps constructor in 'tf.variable_scope()' to ensure variables constructed
+    in 'cls.__init__' are placed under this LayerCollection's scope.
+
+    Args:
+      cls: Class that implements FisherFactor.
+      args: Tuple of arguments to pass into 'cls's constructor. Must be
+        hashable.
+
+    Returns:
+      Instance of 'cls' found in self.fisher_factors.
+    """
+    try:
+      hash(args)
+    except TypeError:
+      raise TypeError((
+          "Unable to use (cls, args) = ({}, {}) as a key in "
+          "LayerCollection.fisher_factors. The pair cannot be hashed."
+      ).format(cls, args))
+
     with variable_scope.variable_scope(self._var_scope):
       return utils.setdefault(self.fisher_factors, (cls, args),
                               lambda: cls(*args))
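As a companion to the ConvDiagonalFB docstring above, here is a minimal NumPy sketch, not part of the patch, of the "sum of squares" diagonal estimator for a 1x1 kernel; the array names and sizes are illustrative assumptions that mirror buildDiagonalFisherApproximation in the new test.

import numpy as np

# Illustrative sizes only; a 1x1 kernel means each location contributes a
# single input vector, as in the new ConvDiagonalFBTest.
batch_size, height, width, in_channels, out_channels = 2, 8, 4, 6, 3
a = np.random.randn(batch_size, height * width, in_channels)    # layer inputs
ds = np.random.randn(batch_size, height * width, out_channels)  # d loss / d s

# v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T ), per example.
v = np.einsum('bli,blo->bio', a, ds)                   # [batch, in, out]

# Fisher diagonal estimate: E[ v(x, y, w)[i]^2 ], expectation over the batch.
fisher_diag = np.mean(np.square(v), axis=0).flatten()  # [in * out]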
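A hypothetical registration sketch, not from the patch, of how the diagonal approximation for conv layers is now reached through LayerCollection and how extra minibatches attach to the resulting block; the constant tensors and their shapes are made up for illustration.

from tensorflow.contrib.kfac.python.ops import layer_collection
from tensorflow.python.ops import array_ops

lc = layer_collection.LayerCollection()
params = array_ops.ones((1, 1, 6, 3))   # [kernel_h, kernel_w, in, out]
inputs = array_ops.ones((2, 8, 4, 6))   # [batch, height, width, in]
outputs = array_ops.ones((2, 8, 4, 3))  # layer preactivations

# With approx=APPROX_DIAGONAL_NAME, register_conv2d now builds a ConvDiagonalFB
# and registers (inputs, outputs) as its first minibatch.
lc.register_conv2d(params, [1, 1, 1, 1], 'SAME', inputs, outputs,
                   approx=layer_collection.APPROX_DIAGONAL_NAME)

# Additional minibatches can be attached before the block's factors are
# instantiated, as exercised by testRegisterAdditionalMinibatch.
block = lc.get_blocks()[0]
block.register_additional_minibatch(array_ops.ones((2, 8, 4, 6)),
                                    array_ops.ones((2, 8, 4, 3)))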
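Finally, a small illustration, again not from the patch, of why make_or_get_factor's args must be hashable; this is the reason ConvDiagonalFB now converts a list-valued strides argument to a tuple before passing it along.

# Tuples containing only hashable elements can key a dict; a tuple containing
# a list cannot, so list-valued strides would break the fisher_factors cache.
factor_cache = {}

args_with_list = ([1, 1, 1, 1], 'SAME')
args_with_tuple = ((1, 1, 1, 1), 'SAME')

try:
  factor_cache[('ConvDiagonalFactor', args_with_list)] = 'factor'
except TypeError:
  print('unhashable args cannot key LayerCollection.fisher_factors')

factor_cache[('ConvDiagonalFactor', args_with_tuple)] = 'factor'  # works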