Multi-minibatch support for tf.contrib.kfac.fisher_blocks.FullyConnectedKFACBasicFB.

PiperOrigin-RevId: 173109677
commit 670dddf4ad (parent dc13a8e2f7)
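At the API level, a FullyConnectedKFACBasicFB is no longer bound to a single (inputs, outputs) pair at construction; minibatches are attached afterwards with register_additional_minibatch, and instantiate_factors now takes a list of per-tower gradient Tensors per source instead of a single Tensor. A minimal sketch of the new calling pattern, distilled from the test changes below (imports follow the contrib layout used by the K-FAC tests; the concrete tensors are illustrative only):

import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
from tensorflow.contrib.kfac.python.ops import layer_collection as lc

inputs = tf.constant([[1., 2.], [3., 4.]])    # [batch_size, input_size]
outputs = tf.constant([[3., 4.], [5., 6.]])   # [batch_size, output_size]
grads = outputs**2                            # stand-in for d(loss)/d(outputs)

# Before this commit: data was bound at construction time.
#   block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), inputs, outputs)
#   block.instantiate_factors((grads,), 0.5)

# After: construct the block first, then register one or more minibatches.
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
block.register_additional_minibatch(inputs, outputs)
# grads_list[i][j] is the gradient w.r.t. 'outputs' from source i and tower j.
block.instantiate_factors(([grads],), 0.5)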
@@ -356,50 +356,51 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
       random_seed.set_random_seed(200)
       inputs = array_ops.constant([1., 2.])
       outputs = array_ops.constant([3., 4.])
-      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), inputs,
-                                           outputs)
+      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection())
+      block.register_additional_minibatch(inputs, outputs)

-      self.assertAllEqual(outputs, block.tensors_to_compute_grads())
+      self.assertAllEqual([outputs], block.tensors_to_compute_grads())

   def testInstantiateFactorsHasBias(self):
     with ops.Graph().as_default():
       random_seed.set_random_seed(200)
       inputs = array_ops.constant([[1., 2.], [3., 4.]])
       outputs = array_ops.constant([[3., 4.], [5., 6.]])
-      block = fb.FullyConnectedKFACBasicFB(
-          lc.LayerCollection(), inputs, outputs, has_bias=True)
+      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True)
+      block.register_additional_minibatch(inputs, outputs)

       grads = outputs**2
-      block.instantiate_factors((grads,), 0.5)
+      block.instantiate_factors(([grads],), 0.5)

   def testInstantiateFactorsNoBias(self):
     with ops.Graph().as_default():
       random_seed.set_random_seed(200)
       inputs = array_ops.constant([[1., 2.], [3., 4.]])
       outputs = array_ops.constant([[3., 4.], [5., 6.]])
-      block = fb.FullyConnectedKFACBasicFB(
-          lc.LayerCollection(), inputs, outputs, has_bias=False)
+      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+      block.register_additional_minibatch(inputs, outputs)

       grads = outputs**2
-      block.instantiate_factors((grads,), 0.5)
+      block.instantiate_factors(([grads],), 0.5)

   def testMultiplyInverseTuple(self):
     with ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]])
       outputs = array_ops.constant([[3., 4.], [5., 6.]])
-      block = fb.FullyConnectedKFACBasicFB(
-          lc.LayerCollection(), inputs, outputs, has_bias=False)
+      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+      block.register_additional_minibatch(inputs, outputs)
       grads = outputs**2
-      block.instantiate_factors((grads,), 0.5)
+      block.instantiate_factors(([grads],), 0.5)

       # Make sure our inverse is something other than the identity.
       sess.run(tf_variables.global_variables_initializer())
       sess.run(block._input_factor.make_inverse_update_ops())
       sess.run(block._output_factor.make_inverse_update_ops())

-      vector = (np.arange(2, 6).reshape(2, 2).astype(np.float32), np.arange(
-          1, 3).reshape(2, 1).astype(np.float32))
+      vector = (
+          np.arange(2, 6).reshape(2, 2).astype(np.float32),  #
+          np.arange(1, 3).reshape(2, 1).astype(np.float32))
       output = block.multiply_inverse((array_ops.constant(vector[0]),
                                        array_ops.constant(vector[1])))
@@ -413,10 +414,10 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
       random_seed.set_random_seed(200)
       inputs = array_ops.constant([[1., 2.], [3., 4.]])
       outputs = array_ops.constant([[3., 4.], [5., 6.]])
-      block = fb.FullyConnectedKFACBasicFB(
-          lc.LayerCollection(), inputs, outputs, has_bias=False)
+      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+      block.register_additional_minibatch(inputs, outputs)
       grads = outputs**2
-      block.instantiate_factors((grads,), 0.5)
+      block.instantiate_factors(([grads],), 0.5)

       # Make sure our inverse is something other than the identity.
       sess.run(tf_variables.global_variables_initializer())
@@ -436,11 +437,11 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
       inputs = array_ops.zeros([32, input_dim])
       outputs = array_ops.zeros([32, output_dim])
       params = array_ops.zeros([input_dim, output_dim])
-      block = fb.FullyConnectedKFACBasicFB(
-          lc.LayerCollection(), inputs, outputs, has_bias=False)
+      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+      block.register_additional_minibatch(inputs, outputs)
       grads = outputs**2
       damping = 0.  # This test is only valid without damping.
-      block.instantiate_factors((grads,), damping)
+      block.instantiate_factors(([grads],), damping)

       sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3)))
       sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
@@ -367,7 +367,7 @@ class ConvDiagonalFB(FisherBlock):
                                (self._strides[1] * self._strides[2]))

     if NORMALIZE_DAMPING_POWER:
-      damping /= self._num_locations**NORMALIZE_DAMPING_POWER
+      damping /= self._num_locations ** NORMALIZE_DAMPING_POWER
     self._damping = damping

     self._factor = self._layer_collection.make_or_get_factor(
@@ -478,34 +478,60 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
     K-FAC paper (https://arxiv.org/abs/1503.05671)
   """

-  def __init__(self, layer_collection, inputs, outputs, has_bias=False):
+  def __init__(self, layer_collection, has_bias=False):
     """Creates a FullyConnectedKFACBasicFB block.

     Args:
       layer_collection: The collection of all layers in the K-FAC approximate
         Fisher information matrix to which this FisherBlock belongs.
-      inputs: The Tensor of input activations to this layer.
-      outputs: The Tensor of output pre-activations from this layer.
       has_bias: Whether the component Kronecker factors have an additive bias.
         (Default: False)
     """
-    self._inputs = inputs
-    self._outputs = outputs
+    self._inputs = []
+    self._outputs = []
     self._has_bias = has_bias

     super(FullyConnectedKFACBasicFB, self).__init__(layer_collection)

   def instantiate_factors(self, grads_list, damping):
-    self._input_factor = self._layer_collection.make_or_get_factor(
-        fisher_factors.FullyConnectedKroneckerFactor,
-        ((self._inputs,), self._has_bias))
-    self._output_factor = self._layer_collection.make_or_get_factor(
-        fisher_factors.FullyConnectedKroneckerFactor, (grads_list,))
+    """Instantiate Kronecker Factors for this FisherBlock.
+
+    Args:
+      grads_list: List of list of Tensors. grads_list[i][j] is the
+        gradient of the loss with respect to 'outputs' from source 'i' and
+        tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size].
+      damping: 0-D Tensor or float. 'damping' * identity is approximately added
+        to this FisherBlock's Fisher approximation.
+    """
+    # TODO(b/68033310): Validate which of,
+    #   (1) summing on a single device (as below), or
+    #   (2) on each device in isolation and aggregating
+    # is faster.
+    inputs = _concat_along_batch_dim(self._inputs)
+    grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
+
+    self._input_factor = self._layer_collection.make_or_get_factor(  #
+        fisher_factors.FullyConnectedKroneckerFactor,  #
+        ((inputs,), self._has_bias))
+    self._output_factor = self._layer_collection.make_or_get_factor(  #
+        fisher_factors.FullyConnectedKroneckerFactor,  #
+        (grads_list,))
     self._register_damped_input_and_output_inverses(damping)

   def tensors_to_compute_grads(self):
     return self._outputs

+  def register_additional_minibatch(self, inputs, outputs):
+    """Registers an additional minibatch to the FisherBlock.
+
+    Args:
+      inputs: Tensor of shape [batch_size, input_size]. Inputs to the
+        matrix-multiply.
+      outputs: Tensor of shape [batch_size, output_size]. Layer preactivations.
+    """
+    self._inputs.append(inputs)
+    self._outputs.append(outputs)
+

 class ConvKFCBasicFB(KroneckerProductFB):
   """FisherBlock for 2D convolutional layers using the basic KFC approx.
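The new instantiate_factors body above calls a module-level helper, _concat_along_batch_dim, whose definition is not included in the hunks shown here. A minimal sketch of what such a helper could look like, assuming per-tower Tensors of shape [tower_minibatch_size, n] are simply concatenated along the leading batch axis (the tuple handling is speculative):

from tensorflow.python.ops import array_ops


def _concat_along_batch_dim(tensor_list):
  """Concatenates per-tower Tensors along their leading (batch) dimension.

  Sketch only; the helper actually added by this commit is not shown above.
  """
  if isinstance(tensor_list[0], (tuple, list)):
    # A list of tuples of Tensors: concatenate position-wise.
    return tuple(
        array_ops.concat(tensors, axis=0) for tensors in zip(*tensor_list))
  return array_ops.concat(tensor_list, axis=0)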
@@ -573,6 +573,14 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
   """

   def __init__(self, tensors, has_bias=False):
+    """Instantiate FullyConnectedKroneckerFactor.
+
+    Args:
+      tensors: List of Tensors of shape [batch_size, n]. Represents either a
+        layer's inputs or its output's gradients.
+      has_bias: bool. If True, assume this factor is for the layer's inputs and
+        append '1' to each row.
+    """
     # The tensor argument is either a tensor of input activations or a tensor of
     # output pre-activation gradients.
     self._has_bias = has_bias
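The has_bias convention documented in the FullyConnectedKroneckerFactor docstring above is the usual K-FAC trick of folding the bias into the input-side factor via a homogeneous coordinate. A small illustrative snippet (not part of the commit) of what "append '1' to each row" means for the input covariance:

import tensorflow as tf

acts = tf.random_normal([8, 3])                    # [batch_size, n]
acts = tf.concat([acts, tf.ones([8, 1])], axis=1)  # [batch_size, n + 1]
# Empirical second moment E[a a^T] used as the input Kronecker factor.
cov = tf.matmul(acts, acts, transpose_a=True) / 8.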
@@ -255,9 +255,9 @@ class LayerCollection(object):
                                  approx=APPROX_KRONECKER_NAME):
     has_bias = isinstance(params, (tuple, list))
     if approx == APPROX_KRONECKER_NAME:
-      self.register_block(params,
-                          fb.FullyConnectedKFACBasicFB(self, inputs, outputs,
-                                                       has_bias))
+      block = fb.FullyConnectedKFACBasicFB(self, has_bias)
+      block.register_additional_minibatch(inputs, outputs)
+      self.register_block(params, block)
     elif approx == APPROX_DIAGONAL_NAME:
       block = fb.FullyConnectedDiagonalFB(self, has_bias)
       block.register_additional_minibatch(inputs, outputs)
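Putting the pieces together, a layer shared across several towers can now register one minibatch per tower on the same block and pass the per-tower gradients as a single inner list. A hedged end-to-end sketch (tensor values and tower count are made up; only the block methods come from this commit):

import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
from tensorflow.contrib.kfac.python.ops import layer_collection as lc

# Two towers, each contributing its own minibatch for the same layer.
tower_inputs = [tf.random_normal([8, 3]), tf.random_normal([8, 3])]
tower_outputs = [tf.random_normal([8, 2]), tf.random_normal([8, 2])]
tower_grads = [out**2 for out in tower_outputs]  # stand-in for d(loss)/d(outputs)

block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True)
for inp, out in zip(tower_inputs, tower_outputs):
  block.register_additional_minibatch(inp, out)

# One loss source, two towers: grads_list[0][j] is the gradient from tower j.
block.instantiate_factors((tower_grads,), 0.5)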