diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
index 80855da2e92..85ac08a1eb7 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
@@ -356,50 +356,51 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
       random_seed.set_random_seed(200)
       inputs = array_ops.constant([1., 2.])
       outputs = array_ops.constant([3., 4.])
-      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), inputs,
-                                           outputs)
+      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection())
+      block.register_additional_minibatch(inputs, outputs)
 
-      self.assertAllEqual(outputs, block.tensors_to_compute_grads())
+      self.assertAllEqual([outputs], block.tensors_to_compute_grads())
 
   def testInstantiateFactorsHasBias(self):
     with ops.Graph().as_default():
       random_seed.set_random_seed(200)
       inputs = array_ops.constant([[1., 2.], [3., 4.]])
       outputs = array_ops.constant([[3., 4.], [5., 6.]])
-      block = fb.FullyConnectedKFACBasicFB(
-          lc.LayerCollection(), inputs, outputs, has_bias=True)
+      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True)
+      block.register_additional_minibatch(inputs, outputs)
 
       grads = outputs**2
-      block.instantiate_factors((grads,), 0.5)
+      block.instantiate_factors(([grads],), 0.5)
 
   def testInstantiateFactorsNoBias(self):
     with ops.Graph().as_default():
       random_seed.set_random_seed(200)
       inputs = array_ops.constant([[1., 2.], [3., 4.]])
       outputs = array_ops.constant([[3., 4.], [5., 6.]])
-      block = fb.FullyConnectedKFACBasicFB(
-          lc.LayerCollection(), inputs, outputs, has_bias=False)
+      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+      block.register_additional_minibatch(inputs, outputs)
 
       grads = outputs**2
-      block.instantiate_factors((grads,), 0.5)
+      block.instantiate_factors(([grads],), 0.5)
 
   def testMultiplyInverseTuple(self):
     with ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]])
       outputs = array_ops.constant([[3., 4.], [5., 6.]])
-      block = fb.FullyConnectedKFACBasicFB(
-          lc.LayerCollection(), inputs, outputs, has_bias=False)
+      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+      block.register_additional_minibatch(inputs, outputs)
       grads = outputs**2
-      block.instantiate_factors((grads,), 0.5)
+      block.instantiate_factors(([grads],), 0.5)
 
       # Make sure our inverse is something other than the identity.
       sess.run(tf_variables.global_variables_initializer())
       sess.run(block._input_factor.make_inverse_update_ops())
       sess.run(block._output_factor.make_inverse_update_ops())
 
-      vector = (np.arange(2, 6).reshape(2, 2).astype(np.float32), np.arange(
-          1, 3).reshape(2, 1).astype(np.float32))
+      vector = (
+          np.arange(2, 6).reshape(2, 2).astype(np.float32),  #
+          np.arange(1, 3).reshape(2, 1).astype(np.float32))
       output = block.multiply_inverse((array_ops.constant(vector[0]),
                                        array_ops.constant(vector[1])))
 
@@ -413,10 +414,10 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
       random_seed.set_random_seed(200)
       inputs = array_ops.constant([[1., 2.], [3., 4.]])
       outputs = array_ops.constant([[3., 4.], [5., 6.]])
-      block = fb.FullyConnectedKFACBasicFB(
-          lc.LayerCollection(), inputs, outputs, has_bias=False)
+      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+      block.register_additional_minibatch(inputs, outputs)
       grads = outputs**2
-      block.instantiate_factors((grads,), 0.5)
+      block.instantiate_factors(([grads],), 0.5)
 
       # Make sure our inverse is something other than the identity.
       sess.run(tf_variables.global_variables_initializer())
@@ -436,11 +437,11 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
       inputs = array_ops.zeros([32, input_dim])
       outputs = array_ops.zeros([32, output_dim])
       params = array_ops.zeros([input_dim, output_dim])
-      block = fb.FullyConnectedKFACBasicFB(
-          lc.LayerCollection(), inputs, outputs, has_bias=False)
+      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+      block.register_additional_minibatch(inputs, outputs)
       grads = outputs**2
       damping = 0.  # This test is only valid without damping.
-      block.instantiate_factors((grads,), damping)
+      block.instantiate_factors(([grads],), damping)
 
       sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3)))
       sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 5e822b5fe32..754c2cc853b 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -367,7 +367,7 @@ class ConvDiagonalFB(FisherBlock):
         (self._strides[1] * self._strides[2]))
 
     if NORMALIZE_DAMPING_POWER:
-      damping /= self._num_locations**NORMALIZE_DAMPING_POWER
+      damping /= self._num_locations ** NORMALIZE_DAMPING_POWER
     self._damping = damping
 
     self._factor = self._layer_collection.make_or_get_factor(
@@ -478,34 +478,60 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
   K-FAC paper (https://arxiv.org/abs/1503.05671)
   """
 
-  def __init__(self, layer_collection, inputs, outputs, has_bias=False):
+  def __init__(self, layer_collection, has_bias=False):
     """Creates a FullyConnectedKFACBasicFB block.
 
     Args:
       layer_collection: The collection of all layers in the K-FAC approximate
           Fisher information matrix to which this FisherBlock belongs.
-      inputs: The Tensor of input activations to this layer.
-      outputs: The Tensor of output pre-activations from this layer.
       has_bias: Whether the component Kronecker factors have an additive bias.
           (Default: False)
     """
-    self._inputs = inputs
-    self._outputs = outputs
+    self._inputs = []
+    self._outputs = []
     self._has_bias = has_bias
 
     super(FullyConnectedKFACBasicFB, self).__init__(layer_collection)
 
   def instantiate_factors(self, grads_list, damping):
-    self._input_factor = self._layer_collection.make_or_get_factor(
-        fisher_factors.FullyConnectedKroneckerFactor,
-        ((self._inputs,), self._has_bias))
-    self._output_factor = self._layer_collection.make_or_get_factor(
-        fisher_factors.FullyConnectedKroneckerFactor, (grads_list,))
+    """Instantiate Kronecker Factors for this FisherBlock.
+
+    Args:
+      grads_list: List of list of Tensors. grads_list[i][j] is the
+        gradient of the loss with respect to 'outputs' from source 'i' and
+        tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size].
+      damping: 0-D Tensor or float. 'damping' * identity is approximately added
+        to this FisherBlock's Fisher approximation.
+    """
+    # TODO(b/68033310): Validate which of,
+    #   (1) summing on a single device (as below), or
+    #   (2) on each device in isolation and aggregating
+    # is faster.
+    inputs = _concat_along_batch_dim(self._inputs)
+    grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
+
+    self._input_factor = self._layer_collection.make_or_get_factor(  #
+        fisher_factors.FullyConnectedKroneckerFactor,  #
+        ((inputs,), self._has_bias))
+    self._output_factor = self._layer_collection.make_or_get_factor(  #
+        fisher_factors.FullyConnectedKroneckerFactor,  #
+        (grads_list,))
     self._register_damped_input_and_output_inverses(damping)
 
   def tensors_to_compute_grads(self):
     return self._outputs
 
+  def register_additional_minibatch(self, inputs, outputs):
+    """Registers an additional minibatch to the FisherBlock.
+
+    Args:
+      inputs: Tensor of shape [batch_size, input_size]. Inputs to the
+        matrix-multiply.
+      outputs: Tensor of shape [batch_size, output_size]. Layer preactivations.
+    """
+    self._inputs.append(inputs)
+    self._outputs.append(outputs)
+
 
 class ConvKFCBasicFB(KroneckerProductFB):
   """FisherBlock for 2D convolutional layers using the basic KFC approx.
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index 86a1782fcf0..b8b524406c3 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -573,6 +573,14 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
   """
 
   def __init__(self, tensors, has_bias=False):
+    """Instantiate FullyConnectedKroneckerFactor.
+
+    Args:
+      tensors: List of Tensors of shape [batch_size, n]. Represents either a
+        layer's inputs or its output's gradients.
+      has_bias: bool. If True, assume this factor is for the layer's inputs and
+        append '1' to each row.
+    """
     # The tensor argument is either a tensor of input activations or a tensor of
     # output pre-activation gradients.
     self._has_bias = has_bias
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index 10ef5543516..ceb1131f286 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -255,9 +255,9 @@ class LayerCollection(object):
                                approx=APPROX_KRONECKER_NAME):
     has_bias = isinstance(params, (tuple, list))
     if approx == APPROX_KRONECKER_NAME:
-      self.register_block(params,
-                          fb.FullyConnectedKFACBasicFB(self, inputs, outputs,
-                                                       has_bias))
+      block = fb.FullyConnectedKFACBasicFB(self, has_bias)
+      block.register_additional_minibatch(inputs, outputs)
+      self.register_block(params, block)
     elif approx == APPROX_DIAGONAL_NAME:
       block = fb.FullyConnectedDiagonalFB(self, has_bias)
       block.register_additional_minibatch(inputs, outputs)