K-FAC: Multi-tower support for ConvKFCBasicFB

PiperOrigin-RevId: 173932013
A. Unique TensorFlower 2017-10-30 12:29:16 -07:00 committed by TensorFlower Gardener
parent 1b6b7e208f
commit b9337de5b3
4 changed files with 121 additions and 37 deletions

View File

@@ -652,10 +652,10 @@ class ConvKFCBasicFBTest(test.TestCase):
params = array_ops.constant(params)
inputs = random_ops.random_normal((2, 2, 2))
outputs = random_ops.random_normal((2, 2, 2))
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, inputs, outputs,
[1, 1, 1], 'SAME')
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, [1, 1, 1], 'SAME')
block.register_additional_minibatch(inputs, outputs)
self.assertAllEqual(outputs, block.tensors_to_compute_grads())
self.assertAllEqual([outputs], block.tensors_to_compute_grads())
def testConvKFCBasicFBInitParamsParamsTuple(self):
self._testConvKFCBasicFBInitParams([np.array([1., 2.]), np.array(3.)])
@@ -669,10 +669,11 @@ class ConvKFCBasicFBTest(test.TestCase):
params = random_ops.random_normal((2, 2, 2, 2))
inputs = random_ops.random_normal((2, 2, 2, 2))
outputs = random_ops.random_normal((2, 2, 2, 2))
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, inputs, outputs,
(1, 1, 1, 1), 'SAME')
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
'SAME')
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
block.instantiate_factors((grads,), 0.5)
block.instantiate_factors(([grads],), 0.5)
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@@ -694,11 +695,12 @@ class ConvKFCBasicFBTest(test.TestCase):
params = random_ops.random_normal((2, 2, 2, 2))
inputs = random_ops.random_normal((2, 2, 2, 2))
outputs = random_ops.random_normal((2, 2, 2, 2))
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, inputs, outputs,
(1, 1, 1, 1), 'SAME')
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
'SAME')
block.register_additional_minibatch(inputs, outputs)
self.assertFalse(block._has_bias)
grads = outputs**2
block.instantiate_factors((grads,), 0.5)
block.instantiate_factors(([grads],), 0.5)
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@@ -716,11 +718,12 @@ class ConvKFCBasicFBTest(test.TestCase):
params = [random_ops.random_normal((2, 2, 2, 2))]
inputs = random_ops.random_normal((2, 2, 2, 2))
outputs = random_ops.random_normal((2, 2, 2, 2))
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, inputs, outputs,
(1, 1, 1, 1), 'SAME')
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
'SAME')
block.register_additional_minibatch(inputs, outputs)
self.assertTrue(block._has_bias)
grads = outputs**2
block.instantiate_factors((grads,), 0.5)
block.instantiate_factors(([grads],), 0.5)
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@@ -738,11 +741,12 @@ class ConvKFCBasicFBTest(test.TestCase):
params = array_ops.zeros((2, 2, 2, 2))
inputs = array_ops.zeros((2, 2, 2, 2))
outputs = array_ops.zeros((2, 2, 2, 2))
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, inputs, outputs,
(1, 1, 1, 1), 'SAME')
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
'SAME')
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
damping = 0. # This test is only valid without damping.
block.instantiate_factors((grads,), damping)
block.instantiate_factors(([grads],), damping)
sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8)))
sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
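Taken together, the test edits above exercise the new two-step flow: construct the block without data, then attach each tower's minibatch. A minimal sketch of that flow, reusing the test module's aliases (the tower tensors and grads here are hypothetical placeholders):

    block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1), 'SAME')
    block.register_additional_minibatch(tower0_inputs, tower0_outputs)
    block.register_additional_minibatch(tower1_inputs, tower1_outputs)
    # One inner list per parameter source, one grads entry per minibatch.
    block.instantiate_factors(([tower0_grads, tower1_grads],), 0.5)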

View File

@@ -454,6 +454,14 @@ class KroneckerProductFB(FisherBlock):
@property
def _renorm_coeff(self):
"""Kronecker factor multiplier coefficient.
If this FisherBlock is represented as 'FB = c * kron(left, right)', then
this is 'c'.
Returns:
0-D Tensor.
"""
return 1.0
def multiply_inverse(self, vector):
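The coefficient matters because inversion distributes over a scaled Kronecker product. A quick numpy check of the identity the docstring describes (illustrative only, not library code):

    import numpy as np

    c = 4.0
    left = np.array([[2.0, 0.0], [0.0, 3.0]])
    right = np.array([[1.0, 0.5], [0.5, 2.0]])
    fb_mat = c * np.kron(left, right)
    # inv(c * kron(left, right)) == (1 / c) * kron(inv(left), inv(right))
    fb_inv = np.kron(np.linalg.inv(left), np.linalg.inv(right)) / c
    assert np.allclose(fb_mat.dot(fb_inv), np.eye(4))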
@@ -560,17 +568,34 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
@property
def num_registered_minibatches(self):
return 1 # Multiple minibatches not supported.
return len(self._inputs)
class ConvKFCBasicFB(KroneckerProductFB):
"""FisherBlock for 2D convolutional layers using the basic KFC approx.
See https://arxiv.org/abs/1602.01407 for details.
Estimates the Fisher Information matrix's block for a convolutional
layer.
Consider a convolutional layer in this model with (unshared) filter matrix
'w'. For a minibatch that produces inputs 'a' and output preactivations 's',
this FisherBlock estimates,
F(w) = #locations * kronecker(E[flat(a) flat(a)^T],
                              E[flat(ds) flat(ds)^T])
where
ds = (d / ds) log p(y | x, w)
#locations = number of (x, y) locations where 'w' is applied.
The expectation is taken over all examples and locations, and flat()
concatenates an array's leading dimensions.
See equation 23 in https://arxiv.org/abs/1602.01407 for details.
"""
def __init__(self, layer_collection, params, inputs, outputs, strides,
padding):
def __init__(self, layer_collection, params, strides, padding):
"""Creates a ConvKFCBasicFB block.
Args:
@@ -580,38 +605,43 @@ class ConvKFCBasicFB(KroneckerProductFB):
kernel alone, a Tensor of shape [kernel_height, kernel_width,
in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
containing the previous and a Tensor of shape [out_channels].
inputs: A Tensor of shape [batch_size, height, width, in_channels].
Input activations to this layer.
outputs: A Tensor of shape [batch_size, height, width, out_channels].
Output pre-activations from this layer.
strides: The stride size in this layer (1-D Tensor of length 4).
padding: The padding in this layer (e.g. "SAME").
"""
self._inputs = inputs
self._outputs = outputs
self._strides = strides
self._inputs = []
self._outputs = []
self._strides = tuple(strides) if isinstance(strides, list) else strides
self._padding = padding
self._has_bias = isinstance(params, (tuple, list))
fltr = params[0] if self._has_bias else params
self._filter_shape = tuple(fltr.shape.as_list())
input_shape = tuple(inputs.shape.as_list())
self._num_locations = (
input_shape[1] * input_shape[2] // (strides[1] * strides[2]))
super(ConvKFCBasicFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
# TODO(b/68033310): Validate which of,
# (1) summing on a single device (as below), or
# (2) on each device in isolation and aggregating
# is faster.
inputs = _concat_along_batch_dim(self._inputs)
grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
# Infer number of locations upon which convolution is applied.
self._num_locations = _num_conv_locations(inputs.shape.as_list(),
self._strides)
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvInputKroneckerFactor,
(self._inputs, self._filter_shape, self._strides, self._padding,
(inputs, self._filter_shape, self._strides, self._padding,
self._has_bias))
self._output_factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
if NORMALIZE_DAMPING_POWER:
damping /= self._num_locations**NORMALIZE_DAMPING_POWER
self._damping = damping
self._register_damped_input_and_output_inverses(damping)
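The normalization divides the damping by the same #locations coefficient that scales the block. A toy arithmetic check, assuming NORMALIZE_DAMPING_POWER == 1 (an assumption; the constant's value is not shown in this diff):

    damping = 0.5
    num_locations = 16
    damping /= num_locations ** 1  # assumed NORMALIZE_DAMPING_POWER of 1
    assert damping == 0.03125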
@property
@@ -621,9 +651,21 @@ class ConvKFCBasicFB(KroneckerProductFB):
def tensors_to_compute_grads(self):
return self._outputs
def register_additional_minibatch(self, inputs, outputs):
"""Registers an additional minibatch to the FisherBlock.
Args:
inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to
the convolution.
outputs: Tensor of shape [batch_size, height, width, output_size]. Layer
preactivations.
"""
self._inputs.append(inputs)
self._outputs.append(outputs)
@property
def num_registered_minibatches(self):
return 1 # Multiple minibatches not supported.
return len(self._inputs)
def _concat_along_batch_dim(tensor_list):
@@ -651,3 +693,8 @@ def _concat_along_batch_dim(tensor_list):
else:
# [tensor1, tensor2] --> tensor
return array_ops.concat(tensor_list, axis=0)
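A hypothetical use of this helper with two towers (module aliases as in this file; the tensor values are placeholders):

    t0 = array_ops.ones((2, 4))                 # tower 0, batch size 2
    t1 = array_ops.zeros((3, 4))                # tower 1, batch size 3
    merged = _concat_along_batch_dim([t0, t1])  # shape [5, 4]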
def _num_conv_locations(input_shape, strides):
"""Returns the number of locations a Conv kernel is applied to."""
return input_shape[1] * input_shape[2] // (strides[1] * strides[2])
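For example, a 16x16 spatial input convolved with stride 2 in both spatial dimensions:

    # input_shape = [batch, height, width, channels]
    assert _num_conv_locations([8, 16, 16, 3], [1, 2, 2, 1]) == 64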

View File

@@ -609,9 +609,28 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
class ConvInputKroneckerFactor(InverseProvidingFactor):
"""Kronecker factor for the input side of a convolutional layer."""
r"""Kronecker factor for the input side of a convolutional layer.
Estimates E[ a a^T ] where a are the inputs to a convolutional layer for a
given example x. The expectation is taken over all examples and locations.
Equivalent to \Omega in https://arxiv.org/abs/1602.01407. See Section 3.1,
"Estimating the factors," for details.
"""
def __init__(self, inputs, filter_shape, strides, padding, has_bias=False):
"""Initializes ConvInputKroneckerFactor.
Args:
inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs
to layer.
filter_shape: 1-D Tensor of length 4. Contains [kernel_height,
kernel_width, in_channels, out_channels].
strides: 1-D Tensor of length 4. Contains [batch_stride, height_stride,
width_stride, in_channel_stride].
padding: str. Padding method for layer. "SAME" or "VALID".
has_bias: bool. If True, append 1 to in_channel.
"""
self._filter_shape = filter_shape
self._strides = strides
self._padding = padding
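The has_bias flag folds the bias into this factor by appending a homogeneous coordinate to each extracted patch. A numpy sketch of that trick (shapes invented for illustration):

    import numpy as np

    patches = np.random.randn(64, 12)  # [batch * locations, flattened patch]
    ones = np.ones((patches.shape[0], 1))
    patches = np.concatenate([patches, ones], axis=1)  # append 1 per patch
    omega = patches.T.dot(patches) / patches.shape[0]  # (patch_size + 1)^2 factor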
@@ -659,9 +678,23 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
class ConvOutputKroneckerFactor(InverseProvidingFactor):
"""Kronecker factor for the output side of a convolutional layer."""
r"""Kronecker factor for the output side of a convolutional layer.
Estimates E[ ds ds^T ] where s are the preactivations of a convolutional
layer for a given example x and ds = (d / d s) log(p(y | x, w)). The
expectation is taken over all examples and locations.
Equivalent to \Gamma in https://arxiv.org/abs/1602.01407. See Section 3.1,
"Estimating the factors," for details.
"""
def __init__(self, outputs_grads):
"""Initializes ConvOutputKroneckerFactor.
Args:
outputs_grads: list of Tensors. Each Tensor is of shape
[batch_size, height, width, out_channels].
"""
self._out_channels = outputs_grads[0].shape.as_list()[3]
self._outputs_grads = outputs_grads
super(ConvOutputKroneckerFactor, self).__init__()
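Under the list-based contract, each tower contributes one grads Tensor and all towers must share out_channels; a sketch with placeholder tensors:

    g0 = random_ops.random_normal((2, 4, 4, 8))   # tower 0 grads
    g1 = random_ops.random_normal((3, 4, 4, 8))   # tower 1 grads
    factor = ConvOutputKroneckerFactor([g0, g1])  # out_channels inferred as 8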

View File

@@ -315,9 +315,9 @@ class LayerCollection(object):
approx=APPROX_KRONECKER_NAME):
if approx == APPROX_KRONECKER_NAME:
self.register_block(params,
fb.ConvKFCBasicFB(self, params, inputs, outputs,
strides, padding))
block = fb.ConvKFCBasicFB(self, params, strides, padding)
block.register_additional_minibatch(inputs, outputs)
self.register_block(params, block)
elif approx == APPROX_DIAGONAL_NAME:
block = fb.ConvDiagonalFB(self, params, strides, padding)
block.register_additional_minibatch(inputs, outputs)
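From the caller's side a single tower registers exactly as before; register_conv2d forwards the one (inputs, outputs) pair to the freshly built block. A hedged sketch (the parameter order is assumed, not shown in this hunk, and the model tensors are placeholders):

    lc = LayerCollection()
    lc.register_conv2d(params, [1, 1, 1, 1], 'SAME', inputs, outputs)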