Multi-minibatch support for
tf.contrib.kfac.fisher_blocks.FullyConnectedKFACBasicFB.

PiperOrigin-RevId: 173109677
This commit is contained in:
parent
dc13a8e2f7
commit
670dddf4ad
tensorflow/contrib/kfac/python
@ -356,50 +356,51 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
|
||||
random_seed.set_random_seed(200)
|
||||
inputs = array_ops.constant([1., 2.])
|
||||
outputs = array_ops.constant([3., 4.])
|
||||
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), inputs,
|
||||
outputs)
|
||||
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection())
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
|
||||
self.assertAllEqual(outputs, block.tensors_to_compute_grads())
|
||||
self.assertAllEqual([outputs], block.tensors_to_compute_grads())
|
||||
|
||||
def testInstantiateFactorsHasBias(self):
|
||||
with ops.Graph().as_default():
|
||||
random_seed.set_random_seed(200)
|
||||
inputs = array_ops.constant([[1., 2.], [3., 4.]])
|
||||
outputs = array_ops.constant([[3., 4.], [5., 6.]])
|
||||
block = fb.FullyConnectedKFACBasicFB(
|
||||
lc.LayerCollection(), inputs, outputs, has_bias=True)
|
||||
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True)
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
|
||||
grads = outputs**2
|
||||
block.instantiate_factors((grads,), 0.5)
|
||||
block.instantiate_factors(([grads],), 0.5)
|
||||
|
||||
def testInstantiateFactorsNoBias(self):
|
||||
with ops.Graph().as_default():
|
||||
random_seed.set_random_seed(200)
|
||||
inputs = array_ops.constant([[1., 2.], [3., 4.]])
|
||||
outputs = array_ops.constant([[3., 4.], [5., 6.]])
|
||||
block = fb.FullyConnectedKFACBasicFB(
|
||||
lc.LayerCollection(), inputs, outputs, has_bias=False)
|
||||
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
|
||||
grads = outputs**2
|
||||
block.instantiate_factors((grads,), 0.5)
|
||||
block.instantiate_factors(([grads],), 0.5)
|
||||
|
||||
def testMultiplyInverseTuple(self):
|
||||
with ops.Graph().as_default(), self.test_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]])
|
||||
outputs = array_ops.constant([[3., 4.], [5., 6.]])
|
||||
block = fb.FullyConnectedKFACBasicFB(
|
||||
lc.LayerCollection(), inputs, outputs, has_bias=False)
|
||||
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
grads = outputs**2
|
||||
block.instantiate_factors((grads,), 0.5)
|
||||
block.instantiate_factors(([grads],), 0.5)
|
||||
|
||||
# Make sure our inverse is something other than the identity.
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
sess.run(block._input_factor.make_inverse_update_ops())
|
||||
sess.run(block._output_factor.make_inverse_update_ops())
|
||||
|
||||
vector = (np.arange(2, 6).reshape(2, 2).astype(np.float32), np.arange(
|
||||
1, 3).reshape(2, 1).astype(np.float32))
|
||||
vector = (
|
||||
np.arange(2, 6).reshape(2, 2).astype(np.float32), #
|
||||
np.arange(1, 3).reshape(2, 1).astype(np.float32))
|
||||
output = block.multiply_inverse((array_ops.constant(vector[0]),
|
||||
array_ops.constant(vector[1])))
|
||||
|
||||
@ -413,10 +414,10 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
|
||||
random_seed.set_random_seed(200)
|
||||
inputs = array_ops.constant([[1., 2.], [3., 4.]])
|
||||
outputs = array_ops.constant([[3., 4.], [5., 6.]])
|
||||
block = fb.FullyConnectedKFACBasicFB(
|
||||
lc.LayerCollection(), inputs, outputs, has_bias=False)
|
||||
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
grads = outputs**2
|
||||
block.instantiate_factors((grads,), 0.5)
|
||||
block.instantiate_factors(([grads],), 0.5)
|
||||
|
||||
# Make sure our inverse is something other than the identity.
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
@ -436,11 +437,11 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
|
||||
inputs = array_ops.zeros([32, input_dim])
|
||||
outputs = array_ops.zeros([32, output_dim])
|
||||
params = array_ops.zeros([input_dim, output_dim])
|
||||
block = fb.FullyConnectedKFACBasicFB(
|
||||
lc.LayerCollection(), inputs, outputs, has_bias=False)
|
||||
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
grads = outputs**2
|
||||
damping = 0. # This test is only valid without damping.
|
||||
block.instantiate_factors((grads,), damping)
|
||||
block.instantiate_factors(([grads],), damping)
|
||||
|
||||
sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3)))
|
||||
sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
|
||||
|
@ -367,7 +367,7 @@ class ConvDiagonalFB(FisherBlock):
|
||||
(self._strides[1] * self._strides[2]))
|
||||
|
||||
if NORMALIZE_DAMPING_POWER:
|
||||
damping /= self._num_locations**NORMALIZE_DAMPING_POWER
|
||||
damping /= self._num_locations ** NORMALIZE_DAMPING_POWER
|
||||
self._damping = damping
|
||||
|
||||
self._factor = self._layer_collection.make_or_get_factor(
|
||||
@ -478,34 +478,60 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
|
||||
K-FAC paper (https://arxiv.org/abs/1503.05671)
|
||||
"""
|
||||
|
||||
def __init__(self, layer_collection, inputs, outputs, has_bias=False):
|
||||
def __init__(self, layer_collection, has_bias=False):
|
||||
"""Creates a FullyConnectedKFACBasicFB block.
|
||||
|
||||
Args:
|
||||
layer_collection: The collection of all layers in the K-FAC approximate
|
||||
Fisher information matrix to which this FisherBlock belongs.
|
||||
inputs: The Tensor of input activations to this layer.
|
||||
outputs: The Tensor of output pre-activations from this layer.
|
||||
has_bias: Whether the component Kronecker factors have an additive bias.
|
||||
(Default: False)
|
||||
"""
|
||||
self._inputs = inputs
|
||||
self._outputs = outputs
|
||||
self._inputs = []
|
||||
self._outputs = []
|
||||
self._has_bias = has_bias
|
||||
|
||||
super(FullyConnectedKFACBasicFB, self).__init__(layer_collection)
|
||||
|
||||
def instantiate_factors(self, grads_list, damping):
|
||||
self._input_factor = self._layer_collection.make_or_get_factor(
|
||||
fisher_factors.FullyConnectedKroneckerFactor,
|
||||
((self._inputs,), self._has_bias))
|
||||
self._output_factor = self._layer_collection.make_or_get_factor(
|
||||
fisher_factors.FullyConnectedKroneckerFactor, (grads_list,))
|
||||
"""Instantiate Kronecker Factors for this FisherBlock.
|
||||
|
||||
Args:
|
||||
grads_list: List of list of Tensors. grads_list[i][j] is the
|
||||
gradient of the loss with respect to 'outputs' from source 'i' and
|
||||
tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size].
|
||||
damping: 0-D Tensor or float. 'damping' * identity is approximately added
|
||||
to this FisherBlock's Fisher approximation.
|
||||
"""
|
||||
# TODO(b/68033310): Validate which of,
|
||||
# (1) summing on a single device (as below), or
|
||||
# (2) on each device in isolation and aggregating
|
||||
# is faster.
|
||||
inputs = _concat_along_batch_dim(self._inputs)
|
||||
grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
|
||||
|
||||
self._input_factor = self._layer_collection.make_or_get_factor( #
|
||||
fisher_factors.FullyConnectedKroneckerFactor, #
|
||||
((inputs,), self._has_bias))
|
||||
self._output_factor = self._layer_collection.make_or_get_factor( #
|
||||
fisher_factors.FullyConnectedKroneckerFactor, #
|
||||
(grads_list,))
|
||||
self._register_damped_input_and_output_inverses(damping)
|
||||
|
||||
def tensors_to_compute_grads(self):
|
||||
return self._outputs
|
||||
|
||||
def register_additional_minibatch(self, inputs, outputs):
|
||||
"""Registers an additional minibatch to the FisherBlock.
|
||||
|
||||
Args:
|
||||
inputs: Tensor of shape [batch_size, input_size]. Inputs to the
|
||||
matrix-multiply.
|
||||
outputs: Tensor of shape [batch_size, output_size]. Layer preactivations.
|
||||
"""
|
||||
self._inputs.append(inputs)
|
||||
self._outputs.append(outputs)
|
||||
|
||||
|
||||
class ConvKFCBasicFB(KroneckerProductFB):
|
||||
"""FisherBlock for 2D convolutional layers using the basic KFC approx.
|
||||
|
@ -573,6 +573,14 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
|
||||
"""
|
||||
|
||||
def __init__(self, tensors, has_bias=False):
|
||||
"""Instantiate FullyConnectedKroneckerFactor.
|
||||
|
||||
Args:
|
||||
tensors: List of Tensors of shape [batch_size, n]. Represents either a
|
||||
layer's inputs or its output's gradients.
|
||||
has_bias: bool. If True, assume this factor is for the layer's inputs and
|
||||
append '1' to each row.
|
||||
"""
|
||||
# The tensor argument is either a tensor of input activations or a tensor of
|
||||
# output pre-activation gradients.
|
||||
self._has_bias = has_bias
|
||||
|
@ -255,9 +255,9 @@ class LayerCollection(object):
|
||||
approx=APPROX_KRONECKER_NAME):
|
||||
has_bias = isinstance(params, (tuple, list))
|
||||
if approx == APPROX_KRONECKER_NAME:
|
||||
self.register_block(params,
|
||||
fb.FullyConnectedKFACBasicFB(self, inputs, outputs,
|
||||
has_bias))
|
||||
block = fb.FullyConnectedKFACBasicFB(self, has_bias)
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
self.register_block(params, block)
|
||||
elif approx == APPROX_DIAGONAL_NAME:
|
||||
block = fb.FullyConnectedDiagonalFB(self, has_bias)
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
|
Loading…
Reference in New Issue
Block a user