K-FAC: Multi-tower support for ConvKFCBasicFB
PiperOrigin-RevId: 173932013
This commit is contained in:
parent
1b6b7e208f
commit
b9337de5b3
@ -652,10 +652,10 @@ class ConvKFCBasicFBTest(test.TestCase):
|
|||||||
params = array_ops.constant(params)
|
params = array_ops.constant(params)
|
||||||
inputs = random_ops.random_normal((2, 2, 2))
|
inputs = random_ops.random_normal((2, 2, 2))
|
||||||
outputs = random_ops.random_normal((2, 2, 2))
|
outputs = random_ops.random_normal((2, 2, 2))
|
||||||
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, inputs, outputs,
|
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, [1, 1, 1], 'SAME')
|
||||||
[1, 1, 1], 'SAME')
|
block.register_additional_minibatch(inputs, outputs)
|
||||||
|
|
||||||
self.assertAllEqual(outputs, block.tensors_to_compute_grads())
|
self.assertAllEqual([outputs], block.tensors_to_compute_grads())
|
||||||
|
|
||||||
def testConvKFCBasicFBInitParamsParamsTuple(self):
|
def testConvKFCBasicFBInitParamsParamsTuple(self):
|
||||||
self._testConvKFCBasicFBInitParams([np.array([1., 2.]), np.array(3.)])
|
self._testConvKFCBasicFBInitParams([np.array([1., 2.]), np.array(3.)])
|
||||||
@ -669,10 +669,11 @@ class ConvKFCBasicFBTest(test.TestCase):
|
|||||||
params = random_ops.random_normal((2, 2, 2, 2))
|
params = random_ops.random_normal((2, 2, 2, 2))
|
||||||
inputs = random_ops.random_normal((2, 2, 2, 2))
|
inputs = random_ops.random_normal((2, 2, 2, 2))
|
||||||
outputs = random_ops.random_normal((2, 2, 2, 2))
|
outputs = random_ops.random_normal((2, 2, 2, 2))
|
||||||
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, inputs, outputs,
|
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
|
||||||
(1, 1, 1, 1), 'SAME')
|
'SAME')
|
||||||
|
block.register_additional_minibatch(inputs, outputs)
|
||||||
grads = outputs**2
|
grads = outputs**2
|
||||||
block.instantiate_factors((grads,), 0.5)
|
block.instantiate_factors(([grads],), 0.5)
|
||||||
|
|
||||||
# Make sure our inverse is something other than the identity.
|
# Make sure our inverse is something other than the identity.
|
||||||
sess.run(tf_variables.global_variables_initializer())
|
sess.run(tf_variables.global_variables_initializer())
|
||||||
@ -694,11 +695,12 @@ class ConvKFCBasicFBTest(test.TestCase):
|
|||||||
params = random_ops.random_normal((2, 2, 2, 2))
|
params = random_ops.random_normal((2, 2, 2, 2))
|
||||||
inputs = random_ops.random_normal((2, 2, 2, 2))
|
inputs = random_ops.random_normal((2, 2, 2, 2))
|
||||||
outputs = random_ops.random_normal((2, 2, 2, 2))
|
outputs = random_ops.random_normal((2, 2, 2, 2))
|
||||||
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, inputs, outputs,
|
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
|
||||||
(1, 1, 1, 1), 'SAME')
|
'SAME')
|
||||||
|
block.register_additional_minibatch(inputs, outputs)
|
||||||
self.assertFalse(block._has_bias)
|
self.assertFalse(block._has_bias)
|
||||||
grads = outputs**2
|
grads = outputs**2
|
||||||
block.instantiate_factors((grads,), 0.5)
|
block.instantiate_factors(([grads],), 0.5)
|
||||||
|
|
||||||
# Make sure our inverse is something other than the identity.
|
# Make sure our inverse is something other than the identity.
|
||||||
sess.run(tf_variables.global_variables_initializer())
|
sess.run(tf_variables.global_variables_initializer())
|
||||||
@ -716,11 +718,12 @@ class ConvKFCBasicFBTest(test.TestCase):
|
|||||||
params = [random_ops.random_normal((2, 2, 2, 2))]
|
params = [random_ops.random_normal((2, 2, 2, 2))]
|
||||||
inputs = random_ops.random_normal((2, 2, 2, 2))
|
inputs = random_ops.random_normal((2, 2, 2, 2))
|
||||||
outputs = random_ops.random_normal((2, 2, 2, 2))
|
outputs = random_ops.random_normal((2, 2, 2, 2))
|
||||||
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, inputs, outputs,
|
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
|
||||||
(1, 1, 1, 1), 'SAME')
|
'SAME')
|
||||||
|
block.register_additional_minibatch(inputs, outputs)
|
||||||
self.assertTrue(block._has_bias)
|
self.assertTrue(block._has_bias)
|
||||||
grads = outputs**2
|
grads = outputs**2
|
||||||
block.instantiate_factors((grads,), 0.5)
|
block.instantiate_factors(([grads],), 0.5)
|
||||||
|
|
||||||
# Make sure our inverse is something other than the identity.
|
# Make sure our inverse is something other than the identity.
|
||||||
sess.run(tf_variables.global_variables_initializer())
|
sess.run(tf_variables.global_variables_initializer())
|
||||||
@ -738,11 +741,12 @@ class ConvKFCBasicFBTest(test.TestCase):
|
|||||||
params = array_ops.zeros((2, 2, 2, 2))
|
params = array_ops.zeros((2, 2, 2, 2))
|
||||||
inputs = array_ops.zeros((2, 2, 2, 2))
|
inputs = array_ops.zeros((2, 2, 2, 2))
|
||||||
outputs = array_ops.zeros((2, 2, 2, 2))
|
outputs = array_ops.zeros((2, 2, 2, 2))
|
||||||
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, inputs, outputs,
|
block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
|
||||||
(1, 1, 1, 1), 'SAME')
|
'SAME')
|
||||||
|
block.register_additional_minibatch(inputs, outputs)
|
||||||
grads = outputs**2
|
grads = outputs**2
|
||||||
damping = 0. # This test is only valid without damping.
|
damping = 0. # This test is only valid without damping.
|
||||||
block.instantiate_factors((grads,), damping)
|
block.instantiate_factors(([grads],), damping)
|
||||||
|
|
||||||
sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8)))
|
sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8)))
|
||||||
sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
|
sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
|
||||||
|
@ -454,6 +454,14 @@ class KroneckerProductFB(FisherBlock):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _renorm_coeff(self):
|
def _renorm_coeff(self):
|
||||||
|
"""Kronecker factor multiplier coefficient.
|
||||||
|
|
||||||
|
If this FisherBlock is represented as 'FB = c * kron(left, right)', then
|
||||||
|
this is 'c'.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
0-D Tensor.
|
||||||
|
"""
|
||||||
return 1.0
|
return 1.0
|
||||||
|
|
||||||
def multiply_inverse(self, vector):
|
def multiply_inverse(self, vector):
|
||||||
@ -560,17 +568,34 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def num_registered_minibatches(self):
|
def num_registered_minibatches(self):
|
||||||
return 1 # Multiple minibatches not supported.
|
return len(self._inputs)
|
||||||
|
|
||||||
|
|
||||||
class ConvKFCBasicFB(KroneckerProductFB):
|
class ConvKFCBasicFB(KroneckerProductFB):
|
||||||
"""FisherBlock for 2D convolutional layers using the basic KFC approx.
|
"""FisherBlock for 2D convolutional layers using the basic KFC approx.
|
||||||
|
|
||||||
See https://arxiv.org/abs/1602.01407 for details.
|
Estimates the Fisher Information matrix's block for a convolutional
|
||||||
|
layer.
|
||||||
|
|
||||||
|
Consider a convolutional layer in this model with (unshared) filter matrix
|
||||||
|
'w'. For a minibatch that produces inputs 'a' and output preactivations 's',
|
||||||
|
this FisherBlock estimates,
|
||||||
|
|
||||||
|
F(w) = #locations * kronecker(E[flat(a) flat(a)^T],
|
||||||
|
E[flat(ds) flat(ds)^T])
|
||||||
|
|
||||||
|
where
|
||||||
|
|
||||||
|
ds = (d / ds) log p(y | x, w)
|
||||||
|
#locations = number of (x, y) locations where 'w' is applied.
|
||||||
|
|
||||||
|
where the expectation is taken over all examples and locations and flat()
|
||||||
|
concatenates an array's leading dimensions.
|
||||||
|
|
||||||
|
See equation 23 in https://arxiv.org/abs/1602.01407 for details.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, layer_collection, params, inputs, outputs, strides,
|
def __init__(self, layer_collection, params, strides, padding):
|
||||||
padding):
|
|
||||||
"""Creates a ConvKFCBasicFB block.
|
"""Creates a ConvKFCBasicFB block.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -580,38 +605,43 @@ class ConvKFCBasicFB(KroneckerProductFB):
|
|||||||
kernel alone, a Tensor of shape [kernel_height, kernel_width,
|
kernel alone, a Tensor of shape [kernel_height, kernel_width,
|
||||||
in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
|
in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
|
||||||
containing the previous and a Tensor of shape [out_channels].
|
containing the previous and a Tensor of shape [out_channels].
|
||||||
inputs: A Tensor of shape [batch_size, height, width, in_channels].
|
|
||||||
Input activations to this layer.
|
|
||||||
outputs: A Tensor of shape [batch_size, height, width, out_channels].
|
|
||||||
Output pre-activations from this layer.
|
|
||||||
strides: The stride size in this layer (1-D Tensor of length 4).
|
strides: The stride size in this layer (1-D Tensor of length 4).
|
||||||
padding: The padding method used in this layer (str, "SAME" or "VALID").
|
padding: The padding method used in this layer (str, "SAME" or "VALID").
|
||||||
"""
|
"""
|
||||||
self._inputs = inputs
|
self._inputs = []
|
||||||
self._outputs = outputs
|
self._outputs = []
|
||||||
self._strides = strides
|
self._strides = tuple(strides) if isinstance(strides, list) else strides
|
||||||
self._padding = padding
|
self._padding = padding
|
||||||
self._has_bias = isinstance(params, (tuple, list))
|
self._has_bias = isinstance(params, (tuple, list))
|
||||||
|
|
||||||
fltr = params[0] if self._has_bias else params
|
fltr = params[0] if self._has_bias else params
|
||||||
self._filter_shape = tuple(fltr.shape.as_list())
|
self._filter_shape = tuple(fltr.shape.as_list())
|
||||||
|
|
||||||
input_shape = tuple(inputs.shape.as_list())
|
|
||||||
self._num_locations = (
|
|
||||||
input_shape[1] * input_shape[2] // (strides[1] * strides[2]))
|
|
||||||
|
|
||||||
super(ConvKFCBasicFB, self).__init__(layer_collection)
|
super(ConvKFCBasicFB, self).__init__(layer_collection)
|
||||||
|
|
||||||
def instantiate_factors(self, grads_list, damping):
|
def instantiate_factors(self, grads_list, damping):
|
||||||
|
# TODO(b/68033310): Validate which of,
|
||||||
|
# (1) summing on a single device (as below), or
|
||||||
|
# (2) on each device in isolation and aggregating
|
||||||
|
# is faster.
|
||||||
|
inputs = _concat_along_batch_dim(self._inputs)
|
||||||
|
grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
|
||||||
|
|
||||||
|
# Infer number of locations upon which convolution is applied.
|
||||||
|
self._num_locations = _num_conv_locations(inputs.shape.as_list(),
|
||||||
|
self._strides)
|
||||||
|
|
||||||
self._input_factor = self._layer_collection.make_or_get_factor(
|
self._input_factor = self._layer_collection.make_or_get_factor(
|
||||||
fisher_factors.ConvInputKroneckerFactor,
|
fisher_factors.ConvInputKroneckerFactor,
|
||||||
(self._inputs, self._filter_shape, self._strides, self._padding,
|
(inputs, self._filter_shape, self._strides, self._padding,
|
||||||
self._has_bias))
|
self._has_bias))
|
||||||
self._output_factor = self._layer_collection.make_or_get_factor(
|
self._output_factor = self._layer_collection.make_or_get_factor(
|
||||||
fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
|
fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
|
||||||
|
|
||||||
if NORMALIZE_DAMPING_POWER:
|
if NORMALIZE_DAMPING_POWER:
|
||||||
damping /= self._num_locations**NORMALIZE_DAMPING_POWER
|
damping /= self._num_locations**NORMALIZE_DAMPING_POWER
|
||||||
|
self._damping = damping
|
||||||
|
|
||||||
self._register_damped_input_and_output_inverses(damping)
|
self._register_damped_input_and_output_inverses(damping)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -621,9 +651,21 @@ class ConvKFCBasicFB(KroneckerProductFB):
|
|||||||
def tensors_to_compute_grads(self):
|
def tensors_to_compute_grads(self):
|
||||||
return self._outputs
|
return self._outputs
|
||||||
|
|
||||||
|
def register_additional_minibatch(self, inputs, outputs):
|
||||||
|
"""Registers an additional minibatch to the FisherBlock.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to
|
||||||
|
the convolution.
|
||||||
|
outputs: Tensor of shape [batch_size, height, width, output_size]. Layer
|
||||||
|
preactivations.
|
||||||
|
"""
|
||||||
|
self._inputs.append(inputs)
|
||||||
|
self._outputs.append(outputs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_registered_minibatches(self):
|
def num_registered_minibatches(self):
|
||||||
return 1 # Multiple minibatches not supported.
|
return len(self._inputs)
|
||||||
|
|
||||||
|
|
||||||
def _concat_along_batch_dim(tensor_list):
|
def _concat_along_batch_dim(tensor_list):
|
||||||
@ -651,3 +693,8 @@ def _concat_along_batch_dim(tensor_list):
|
|||||||
else:
|
else:
|
||||||
# [tensor1, tensor2] --> tensor
|
# [tensor1, tensor2] --> tensor
|
||||||
return array_ops.concat(tensor_list, axis=0)
|
return array_ops.concat(tensor_list, axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
def _num_conv_locations(input_shape, strides):
|
||||||
|
"""Returns the number of locations a Conv kernel is applied to."""
|
||||||
|
return input_shape[1] * input_shape[2] // (strides[1] * strides[2])
|
||||||
|
@ -609,9 +609,28 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
|
|||||||
|
|
||||||
|
|
||||||
class ConvInputKroneckerFactor(InverseProvidingFactor):
|
class ConvInputKroneckerFactor(InverseProvidingFactor):
|
||||||
"""Kronecker factor for the input side of a convolutional layer."""
|
r"""Kronecker factor for the input side of a convolutional layer.
|
||||||
|
|
||||||
|
Estimates E[ a a^T ] where a is the inputs to a convolutional layer given
|
||||||
|
example x. Expectation is taken over all examples and locations.
|
||||||
|
|
||||||
|
Equivalent to \Omega in https://arxiv.org/abs/1602.01407. See
|
||||||
|
Section 3.1 Estimating the factors.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, inputs, filter_shape, strides, padding, has_bias=False):
|
def __init__(self, inputs, filter_shape, strides, padding, has_bias=False):
|
||||||
|
"""Initializes ConvInputKroneckerFactor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs
|
||||||
|
to layer.
|
||||||
|
filter_shape: 1-D Tensor of length 4. Contains [kernel_height,
|
||||||
|
kernel_width, in_channels, out_channels].
|
||||||
|
strides: 1-D Tensor of length 4. Contains [batch_stride, height_stride,
|
||||||
|
width_stride, in_channel_stride].
|
||||||
|
padding: str. Padding method for layer. "SAME" or "VALID".
|
||||||
|
has_bias: bool. If True, append 1 to in_channel.
|
||||||
|
"""
|
||||||
self._filter_shape = filter_shape
|
self._filter_shape = filter_shape
|
||||||
self._strides = strides
|
self._strides = strides
|
||||||
self._padding = padding
|
self._padding = padding
|
||||||
@ -659,9 +678,23 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
|
|||||||
|
|
||||||
|
|
||||||
class ConvOutputKroneckerFactor(InverseProvidingFactor):
|
class ConvOutputKroneckerFactor(InverseProvidingFactor):
|
||||||
"""Kronecker factor for the output side of a convolutional layer."""
|
r"""Kronecker factor for the output side of a convolutional layer.
|
||||||
|
|
||||||
|
Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer
|
||||||
|
given example x and ds = (d / d s) log(p(y|x, w)). Expectation is taken over
|
||||||
|
all examples and locations.
|
||||||
|
|
||||||
|
Equivalent to \Gamma in https://arxiv.org/abs/1602.01407. See
|
||||||
|
Section 3.1 Estimating the factors.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, outputs_grads):
|
def __init__(self, outputs_grads):
|
||||||
|
"""Initializes ConvOutputKroneckerFactor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs_grads: list of Tensors. Each Tensor is of shape
|
||||||
|
[batch_size, height, width, out_channels].
|
||||||
|
"""
|
||||||
self._out_channels = outputs_grads[0].shape.as_list()[3]
|
self._out_channels = outputs_grads[0].shape.as_list()[3]
|
||||||
self._outputs_grads = outputs_grads
|
self._outputs_grads = outputs_grads
|
||||||
super(ConvOutputKroneckerFactor, self).__init__()
|
super(ConvOutputKroneckerFactor, self).__init__()
|
||||||
|
@ -315,9 +315,9 @@ class LayerCollection(object):
|
|||||||
approx=APPROX_KRONECKER_NAME):
|
approx=APPROX_KRONECKER_NAME):
|
||||||
|
|
||||||
if approx == APPROX_KRONECKER_NAME:
|
if approx == APPROX_KRONECKER_NAME:
|
||||||
self.register_block(params,
|
block = fb.ConvKFCBasicFB(self, params, strides, padding)
|
||||||
fb.ConvKFCBasicFB(self, params, inputs, outputs,
|
block.register_additional_minibatch(inputs, outputs)
|
||||||
strides, padding))
|
self.register_block(params, block)
|
||||||
elif approx == APPROX_DIAGONAL_NAME:
|
elif approx == APPROX_DIAGONAL_NAME:
|
||||||
block = fb.ConvDiagonalFB(self, params, strides, padding)
|
block = fb.ConvDiagonalFB(self, params, strides, padding)
|
||||||
block.register_additional_minibatch(inputs, outputs)
|
block.register_additional_minibatch(inputs, outputs)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user