Multi-minibatch support for

tf.contrib.kfac.fisher_blocks.FullyConnectedKFACBasicFB.

PiperOrigin-RevId: 173109677
This commit is contained in:
A. Unique TensorFlower 2017-10-23 06:53:02 -07:00 committed by TensorFlower Gardener
parent dc13a8e2f7
commit 670dddf4ad
4 changed files with 69 additions and 34 deletions

View File

@@ -356,50 +356,51 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
random_seed.set_random_seed(200)
inputs = array_ops.constant([1., 2.])
outputs = array_ops.constant([3., 4.])
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), inputs,
outputs)
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection())
block.register_additional_minibatch(inputs, outputs)
self.assertAllEqual(outputs, block.tensors_to_compute_grads())
self.assertAllEqual([outputs], block.tensors_to_compute_grads())
def testInstantiateFactorsHasBias(self):
with ops.Graph().as_default():
random_seed.set_random_seed(200)
inputs = array_ops.constant([[1., 2.], [3., 4.]])
outputs = array_ops.constant([[3., 4.], [5., 6.]])
block = fb.FullyConnectedKFACBasicFB(
lc.LayerCollection(), inputs, outputs, has_bias=True)
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True)
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
block.instantiate_factors((grads,), 0.5)
block.instantiate_factors(([grads],), 0.5)
def testInstantiateFactorsNoBias(self):
with ops.Graph().as_default():
random_seed.set_random_seed(200)
inputs = array_ops.constant([[1., 2.], [3., 4.]])
outputs = array_ops.constant([[3., 4.], [5., 6.]])
block = fb.FullyConnectedKFACBasicFB(
lc.LayerCollection(), inputs, outputs, has_bias=False)
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
block.instantiate_factors((grads,), 0.5)
block.instantiate_factors(([grads],), 0.5)
def testMultiplyInverseTuple(self):
with ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]])
outputs = array_ops.constant([[3., 4.], [5., 6.]])
block = fb.FullyConnectedKFACBasicFB(
lc.LayerCollection(), inputs, outputs, has_bias=False)
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
block.instantiate_factors((grads,), 0.5)
block.instantiate_factors(([grads],), 0.5)
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
sess.run(block._input_factor.make_inverse_update_ops())
sess.run(block._output_factor.make_inverse_update_ops())
vector = (np.arange(2, 6).reshape(2, 2).astype(np.float32), np.arange(
1, 3).reshape(2, 1).astype(np.float32))
vector = (
np.arange(2, 6).reshape(2, 2).astype(np.float32), #
np.arange(1, 3).reshape(2, 1).astype(np.float32))
output = block.multiply_inverse((array_ops.constant(vector[0]),
array_ops.constant(vector[1])))
@@ -413,10 +414,10 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
random_seed.set_random_seed(200)
inputs = array_ops.constant([[1., 2.], [3., 4.]])
outputs = array_ops.constant([[3., 4.], [5., 6.]])
block = fb.FullyConnectedKFACBasicFB(
lc.LayerCollection(), inputs, outputs, has_bias=False)
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
block.instantiate_factors((grads,), 0.5)
block.instantiate_factors(([grads],), 0.5)
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@@ -436,11 +437,11 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
inputs = array_ops.zeros([32, input_dim])
outputs = array_ops.zeros([32, output_dim])
params = array_ops.zeros([input_dim, output_dim])
block = fb.FullyConnectedKFACBasicFB(
lc.LayerCollection(), inputs, outputs, has_bias=False)
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
damping = 0. # This test is only valid without damping.
block.instantiate_factors((grads,), damping)
block.instantiate_factors(([grads],), damping)
sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3)))
sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))

View File

@@ -367,7 +367,7 @@ class ConvDiagonalFB(FisherBlock):
(self._strides[1] * self._strides[2]))
if NORMALIZE_DAMPING_POWER:
damping /= self._num_locations**NORMALIZE_DAMPING_POWER
damping /= self._num_locations ** NORMALIZE_DAMPING_POWER
self._damping = damping
self._factor = self._layer_collection.make_or_get_factor(
@@ -478,34 +478,60 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
K-FAC paper (https://arxiv.org/abs/1503.05671)
"""
def __init__(self, layer_collection, inputs, outputs, has_bias=False):
def __init__(self, layer_collection, has_bias=False):
"""Creates a FullyConnectedKFACBasicFB block.
Args:
layer_collection: The collection of all layers in the K-FAC approximate
Fisher information matrix to which this FisherBlock belongs.
inputs: The Tensor of input activations to this layer.
outputs: The Tensor of output pre-activations from this layer.
has_bias: Whether the component Kronecker factors have an additive bias.
(Default: False)
"""
self._inputs = inputs
self._outputs = outputs
self._inputs = []
self._outputs = []
self._has_bias = has_bias
super(FullyConnectedKFACBasicFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullyConnectedKroneckerFactor,
((self._inputs,), self._has_bias))
self._output_factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullyConnectedKroneckerFactor, (grads_list,))
"""Instantiate Kronecker Factors for this FisherBlock.
Args:
grads_list: List of list of Tensors. grads_list[i][j] is the
gradient of the loss with respect to 'outputs' from source 'i' and
tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size].
damping: 0-D Tensor or float. 'damping' * identity is approximately added
to this FisherBlock's Fisher approximation.
"""
# TODO(b/68033310): Validate which of,
# (1) summing on a single device (as below), or
# (2) on each device in isolation and aggregating
# is faster.
inputs = _concat_along_batch_dim(self._inputs)
grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
self._input_factor = self._layer_collection.make_or_get_factor( #
fisher_factors.FullyConnectedKroneckerFactor, #
((inputs,), self._has_bias))
self._output_factor = self._layer_collection.make_or_get_factor( #
fisher_factors.FullyConnectedKroneckerFactor, #
(grads_list,))
self._register_damped_input_and_output_inverses(damping)
def tensors_to_compute_grads(self):
return self._outputs
def register_additional_minibatch(self, inputs, outputs):
"""Registers an additional minibatch to the FisherBlock.
Args:
inputs: Tensor of shape [batch_size, input_size]. Inputs to the
matrix-multiply.
outputs: Tensor of shape [batch_size, output_size]. Layer preactivations.
"""
self._inputs.append(inputs)
self._outputs.append(outputs)
class ConvKFCBasicFB(KroneckerProductFB):
"""FisherBlock for 2D convolutional layers using the basic KFC approx.

View File

@@ -573,6 +573,14 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
"""
def __init__(self, tensors, has_bias=False):
"""Instantiate FullyConnectedKroneckerFactor.
Args:
tensors: List of Tensors of shape [batch_size, n]. Represents either a
layer's inputs or its output's gradients.
has_bias: bool. If True, assume this factor is for the layer's inputs and
append '1' to each row.
"""
# The tensor argument is either a tensor of input activations or a tensor of
# output pre-activation gradients.
self._has_bias = has_bias

View File

@@ -255,9 +255,9 @@ class LayerCollection(object):
approx=APPROX_KRONECKER_NAME):
has_bias = isinstance(params, (tuple, list))
if approx == APPROX_KRONECKER_NAME:
self.register_block(params,
fb.FullyConnectedKFACBasicFB(self, inputs, outputs,
has_bias))
block = fb.FullyConnectedKFACBasicFB(self, has_bias)
block.register_additional_minibatch(inputs, outputs)
self.register_block(params, block)
elif approx == APPROX_DIAGONAL_NAME:
block = fb.FullyConnectedDiagonalFB(self, has_bias)
block.register_additional_minibatch(inputs, outputs)