K-FAC: Multi-tower support for ConvKFCBasicFB
PiperOrigin-RevId: 173932013
This commit is contained in:
parent
1b6b7e208f
commit
b9337de5b3
@ -652,10 +652,10 @@ class ConvKFCBasicFBTest(test.TestCase):
|
||||
params = array_ops.constant(params)
|
||||
inputs = random_ops.random_normal((2, 2, 2))
|
||||
outputs = random_ops.random_normal((2, 2, 2))
|
||||
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, inputs, outputs,
|
||||
[1, 1, 1], 'SAME')
|
||||
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, [1, 1, 1], 'SAME')
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
|
||||
self.assertAllEqual(outputs, block.tensors_to_compute_grads())
|
||||
self.assertAllEqual([outputs], block.tensors_to_compute_grads())
|
||||
|
||||
def testConvKFCBasicFBInitParamsParamsTuple(self):
|
||||
self._testConvKFCBasicFBInitParams([np.array([1., 2.]), np.array(3.)])
|
||||
@ -669,10 +669,11 @@ class ConvKFCBasicFBTest(test.TestCase):
|
||||
params = random_ops.random_normal((2, 2, 2, 2))
|
||||
inputs = random_ops.random_normal((2, 2, 2, 2))
|
||||
outputs = random_ops.random_normal((2, 2, 2, 2))
|
||||
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, inputs, outputs,
|
||||
(1, 1, 1, 1), 'SAME')
|
||||
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
|
||||
'SAME')
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
grads = outputs**2
|
||||
block.instantiate_factors((grads,), 0.5)
|
||||
block.instantiate_factors(([grads],), 0.5)
|
||||
|
||||
# Make sure our inverse is something other than the identity.
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
@ -694,11 +695,12 @@ class ConvKFCBasicFBTest(test.TestCase):
|
||||
params = random_ops.random_normal((2, 2, 2, 2))
|
||||
inputs = random_ops.random_normal((2, 2, 2, 2))
|
||||
outputs = random_ops.random_normal((2, 2, 2, 2))
|
||||
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, inputs, outputs,
|
||||
(1, 1, 1, 1), 'SAME')
|
||||
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
|
||||
'SAME')
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
self.assertFalse(block._has_bias)
|
||||
grads = outputs**2
|
||||
block.instantiate_factors((grads,), 0.5)
|
||||
block.instantiate_factors(([grads],), 0.5)
|
||||
|
||||
# Make sure our inverse is something other than the identity.
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
@ -716,11 +718,12 @@ class ConvKFCBasicFBTest(test.TestCase):
|
||||
params = [random_ops.random_normal((2, 2, 2, 2))]
|
||||
inputs = random_ops.random_normal((2, 2, 2, 2))
|
||||
outputs = random_ops.random_normal((2, 2, 2, 2))
|
||||
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, inputs, outputs,
|
||||
(1, 1, 1, 1), 'SAME')
|
||||
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
|
||||
'SAME')
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
self.assertTrue(block._has_bias)
|
||||
grads = outputs**2
|
||||
block.instantiate_factors((grads,), 0.5)
|
||||
block.instantiate_factors(([grads],), 0.5)
|
||||
|
||||
# Make sure our inverse is something other than the identity.
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
@ -738,11 +741,12 @@ class ConvKFCBasicFBTest(test.TestCase):
|
||||
params = array_ops.zeros((2, 2, 2, 2))
|
||||
inputs = array_ops.zeros((2, 2, 2, 2))
|
||||
outputs = array_ops.zeros((2, 2, 2, 2))
|
||||
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, inputs, outputs,
|
||||
(1, 1, 1, 1), 'SAME')
|
||||
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
|
||||
'SAME')
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
grads = outputs**2
|
||||
damping = 0. # This test is only valid without damping.
|
||||
block.instantiate_factors((grads,), damping)
|
||||
block.instantiate_factors(([grads],), damping)
|
||||
|
||||
sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8)))
|
||||
sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
|
||||
|
@ -454,6 +454,14 @@ class KroneckerProductFB(FisherBlock):
|
||||
|
||||
@property
|
||||
def _renorm_coeff(self):
|
||||
"""Kronecker factor multiplier coefficient.
|
||||
|
||||
If this FisherBlock is represented as 'FB = c * kron(left, right)', then
|
||||
this is 'c'.
|
||||
|
||||
Returns:
|
||||
0-D Tensor.
|
||||
"""
|
||||
return 1.0
|
||||
|
||||
def multiply_inverse(self, vector):
|
||||
@ -560,17 +568,34 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
|
||||
|
||||
@property
|
||||
def num_registered_minibatches(self):
|
||||
return 1 # Multiple minibatches not supported.
|
||||
return len(self._inputs)
|
||||
|
||||
|
||||
class ConvKFCBasicFB(KroneckerProductFB):
|
||||
"""FisherBlock for 2D convolutional layers using the basic KFC approx.
|
||||
|
||||
See https://arxiv.org/abs/1602.01407 for details.
|
||||
Estimates the Fisher Information matrix's blog for a convolutional
|
||||
layer.
|
||||
|
||||
Consider a convoluational layer in this model with (unshared) filter matrix
|
||||
'w'. For a minibatch that produces inputs 'a' and output preactivations 's',
|
||||
this FisherBlock estimates,
|
||||
|
||||
F(w) = #locations * kronecker(E[flat(a) flat(a)^T],
|
||||
E[flat(ds) flat(ds)^T])
|
||||
|
||||
where
|
||||
|
||||
ds = (d / ds) log p(y | x, w)
|
||||
#locations = number of (x, y) locations where 'w' is applied.
|
||||
|
||||
where the expectation is taken over all examples and locations and flat()
|
||||
concatenates an array's leading dimensions.
|
||||
|
||||
See equation 23 in https://arxiv.org/abs/1602.01407 for details.
|
||||
"""
|
||||
|
||||
def __init__(self, layer_collection, params, inputs, outputs, strides,
|
||||
padding):
|
||||
def __init__(self, layer_collection, params, strides, padding):
|
||||
"""Creates a ConvKFCBasicFB block.
|
||||
|
||||
Args:
|
||||
@ -580,38 +605,43 @@ class ConvKFCBasicFB(KroneckerProductFB):
|
||||
kernel alone, a Tensor of shape [kernel_height, kernel_width,
|
||||
in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
|
||||
containing the previous and a Tensor of shape [out_channels].
|
||||
inputs: A Tensor of shape [batch_size, height, width, in_channels].
|
||||
Input activations to this layer.
|
||||
outputs: A Tensor of shape [batch_size, height, width, out_channels].
|
||||
Output pre-activations from this layer.
|
||||
strides: The stride size in this layer (1-D Tensor of length 4).
|
||||
padding: The padding in this layer (1-D of Tensor length 4).
|
||||
"""
|
||||
self._inputs = inputs
|
||||
self._outputs = outputs
|
||||
self._strides = strides
|
||||
self._inputs = []
|
||||
self._outputs = []
|
||||
self._strides = tuple(strides) if isinstance(strides, list) else strides
|
||||
self._padding = padding
|
||||
self._has_bias = isinstance(params, (tuple, list))
|
||||
|
||||
fltr = params[0] if self._has_bias else params
|
||||
self._filter_shape = tuple(fltr.shape.as_list())
|
||||
|
||||
input_shape = tuple(inputs.shape.as_list())
|
||||
self._num_locations = (
|
||||
input_shape[1] * input_shape[2] // (strides[1] * strides[2]))
|
||||
|
||||
super(ConvKFCBasicFB, self).__init__(layer_collection)
|
||||
|
||||
def instantiate_factors(self, grads_list, damping):
|
||||
# TODO(b/68033310): Validate which of,
|
||||
# (1) summing on a single device (as below), or
|
||||
# (2) on each device in isolation and aggregating
|
||||
# is faster.
|
||||
inputs = _concat_along_batch_dim(self._inputs)
|
||||
grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
|
||||
|
||||
# Infer number of locations upon which convolution is applied.
|
||||
self._num_locations = _num_conv_locations(inputs.shape.as_list(),
|
||||
self._strides)
|
||||
|
||||
self._input_factor = self._layer_collection.make_or_get_factor(
|
||||
fisher_factors.ConvInputKroneckerFactor,
|
||||
(self._inputs, self._filter_shape, self._strides, self._padding,
|
||||
(inputs, self._filter_shape, self._strides, self._padding,
|
||||
self._has_bias))
|
||||
self._output_factor = self._layer_collection.make_or_get_factor(
|
||||
fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
|
||||
|
||||
if NORMALIZE_DAMPING_POWER:
|
||||
damping /= self._num_locations**NORMALIZE_DAMPING_POWER
|
||||
self._damping = damping
|
||||
|
||||
self._register_damped_input_and_output_inverses(damping)
|
||||
|
||||
@property
|
||||
@ -621,9 +651,21 @@ class ConvKFCBasicFB(KroneckerProductFB):
|
||||
def tensors_to_compute_grads(self):
|
||||
return self._outputs
|
||||
|
||||
def register_additional_minibatch(self, inputs, outputs):
|
||||
"""Registers an additional minibatch to the FisherBlock.
|
||||
|
||||
Args:
|
||||
inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to
|
||||
the convolution.
|
||||
outputs: Tensor of shape [batch_size, height, width, output_size]. Layer
|
||||
preactivations.
|
||||
"""
|
||||
self._inputs.append(inputs)
|
||||
self._outputs.append(outputs)
|
||||
|
||||
@property
|
||||
def num_registered_minibatches(self):
|
||||
return 1 # Multiple minibatches not supported.
|
||||
return len(self._inputs)
|
||||
|
||||
|
||||
def _concat_along_batch_dim(tensor_list):
|
||||
@ -651,3 +693,8 @@ def _concat_along_batch_dim(tensor_list):
|
||||
else:
|
||||
# [tensor1, tensor2] --> tensor
|
||||
return array_ops.concat(tensor_list, axis=0)
|
||||
|
||||
|
||||
def _num_conv_locations(input_shape, strides):
|
||||
"""Returns the number of locations a Conv kernel is applied to."""
|
||||
return input_shape[1] * input_shape[2] // (strides[1] * strides[2])
|
||||
|
@ -609,9 +609,28 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
|
||||
|
||||
|
||||
class ConvInputKroneckerFactor(InverseProvidingFactor):
|
||||
"""Kronecker factor for the input side of a convolutional layer."""
|
||||
r"""Kronecker factor for the input side of a convolutional layer.
|
||||
|
||||
Estimates E[ a a^T ] where a is the inputs to a convolutional layer given
|
||||
example x. Expectation is taken over all examples and locations.
|
||||
|
||||
Equivalent to \Omega in https://arxiv.org/abs/1602.01407 for details. See
|
||||
Section 3.1 Estimating the factors.
|
||||
"""
|
||||
|
||||
def __init__(self, inputs, filter_shape, strides, padding, has_bias=False):
|
||||
"""Initializes ConvInputKroneckerFactor.
|
||||
|
||||
Args:
|
||||
inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs
|
||||
to layer.
|
||||
filter_shape: 1-D Tensor of length 4. Contains [kernel_height,
|
||||
kernel_width, in_channels, out_channels].
|
||||
strides: 1-D Tensor of length 4. Contains [batch_stride, height_stride,
|
||||
width_stride, in_channel_stride].
|
||||
padding: str. Padding method for layer. "SAME" or "VALID".
|
||||
has_bias: bool. If True, append 1 to in_channel.
|
||||
"""
|
||||
self._filter_shape = filter_shape
|
||||
self._strides = strides
|
||||
self._padding = padding
|
||||
@ -659,9 +678,23 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
|
||||
|
||||
|
||||
class ConvOutputKroneckerFactor(InverseProvidingFactor):
|
||||
"""Kronecker factor for the output side of a convolutional layer."""
|
||||
r"""Kronecker factor for the output side of a convolutional layer.
|
||||
|
||||
Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer
|
||||
given example x and ds = (d / d s) log(p(y|x, w)). Expectation is taken over
|
||||
all examples and locations.
|
||||
|
||||
Equivalent to \Gamma in https://arxiv.org/abs/1602.01407 for details. See
|
||||
Section 3.1 Estimating the factors.
|
||||
"""
|
||||
|
||||
def __init__(self, outputs_grads):
|
||||
"""Initializes ConvOutputKroneckerFactor.
|
||||
|
||||
Args:
|
||||
outputs_grads: list of Tensors. Each Tensor is of shape
|
||||
[batch_size, height, width, out_channels].
|
||||
"""
|
||||
self._out_channels = outputs_grads[0].shape.as_list()[3]
|
||||
self._outputs_grads = outputs_grads
|
||||
super(ConvOutputKroneckerFactor, self).__init__()
|
||||
|
@ -315,9 +315,9 @@ class LayerCollection(object):
|
||||
approx=APPROX_KRONECKER_NAME):
|
||||
|
||||
if approx == APPROX_KRONECKER_NAME:
|
||||
self.register_block(params,
|
||||
fb.ConvKFCBasicFB(self, params, inputs, outputs,
|
||||
strides, padding))
|
||||
block = fb.ConvKFCBasicFB(self, params, strides, padding)
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
self.register_block(params, block)
|
||||
elif approx == APPROX_DIAGONAL_NAME:
|
||||
block = fb.ConvDiagonalFB(self, params, strides, padding)
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
|
Loading…
Reference in New Issue
Block a user