K-FAC: Multi-tower support for ConvDiagonalFB.

PiperOrigin-RevId: 173105412
A. Unique TensorFlower 2017-10-23 06:00:06 -07:00 committed by TensorFlower Gardener
parent fd8d517b97
commit eea089bdb6
4 changed files with 271 additions and 34 deletions


@@ -328,17 +328,11 @@ class FullyConnectedDiagonalFB(test.TestCase):
multiply_result: Result of FisherBlock.multiply(params)
multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
"""
def _as_tensors(tensor_or_tuple):
if isinstance(tensor_or_tuple, (tuple, list)):
return tuple(ops.convert_to_tensor(t) for t in tensor_or_tuple)
return ops.convert_to_tensor(tensor_or_tuple)
with ops.Graph().as_default(), self.test_session() as sess:
inputs = [_as_tensors(i) for i in inputs]
outputs = [_as_tensors(o) for o in outputs]
output_grads = [_as_tensors(og) for og in output_grads]
params = _as_tensors(params)
inputs = as_tensors(inputs)
outputs = as_tensors(outputs)
output_grads = as_tensors(output_grads)
params = as_tensors(params)
block = fb.FullyConnectedDiagonalFB(
lc.LayerCollection(), has_bias=isinstance(params, (tuple, list)))
@@ -464,6 +458,188 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
self.assertAllClose(output_flat, explicit)
class ConvDiagonalFBTest(test.TestCase):
def setUp(self):
super(ConvDiagonalFBTest, self).setUp()
self.batch_size = 2
self.height = 8
self.width = 4
self.input_channels = 6
self.output_channels = 3
self.kernel_size = 1
self.inputs = np.random.randn(self.batch_size, self.height, self.width,
self.input_channels).astype(np.float32)
self.outputs = np.zeros(
[self.batch_size, self.height, self.width,
self.output_channels]).astype(np.float32)
self.output_grads = np.random.randn(
self.batch_size, self.height, self.width, self.output_channels).astype(
np.float32)
self.w = np.random.randn(self.kernel_size, self.kernel_size,
self.input_channels, self.output_channels).astype(
np.float32)
self.b = np.random.randn(self.output_channels).astype(np.float32)
def fisherApprox(self, has_bias=False):
"""Fisher approximation using default inputs."""
if has_bias:
inputs = np.concatenate(
[self.inputs,
np.ones([self.batch_size, self.height, self.width, 1])],
axis=-1)
else:
inputs = self.inputs
return self.buildDiagonalFisherApproximation(inputs, self.output_grads,
self.kernel_size)
def buildDiagonalFisherApproximation(self, inputs, output_grads, kernel_size):
r"""Builds explicit diagonal Fisher approximation.
Fisher's diagonal is E[ (d loss / d w)[i]^2 ], where per example
d loss / d w = \sum_{loc} outer(input_{loc}, output_grad_{loc}),
the expectation is taken over examples, and the sum runs over the (x, y)
locations upon which the convolution is applied.
Args:
inputs: np.array of shape [batch_size, height, width, input_channels].
output_grads: np.array of shape [batch_size, height, width,
output_channels].
kernel_size: int. height and width of kernel.
Returns:
Diagonal np.array of shape [num_params, num_params] for num_params =
kernel_size^2 * input_channels * output_channels.
"""
batch_size, height, width, input_channels = inputs.shape
assert output_grads.shape[0] == batch_size
assert output_grads.shape[1] == height
assert output_grads.shape[2] == width
output_channels = output_grads.shape[3]
# If kernel_size == 1, then we don't need to worry about capturing context
# around the pixel upon which a convolution is applied. This makes testing
# easier.
assert kernel_size == 1, "kernel_size != 1 isn't supported."
num_locations = height * width
inputs = np.reshape(inputs, [batch_size, num_locations, input_channels])
output_grads = np.reshape(output_grads,
[batch_size, num_locations, output_channels])
fisher_diag = np.zeros((input_channels, output_channels))
for i in range(batch_size):
# Each example's approximation is a square(sum-of-outer-products).
example_fisher_diag = np.zeros((input_channels, output_channels))
for j in range(num_locations):
example_fisher_diag += np.outer(inputs[i, j], output_grads[i, j])
fisher_diag += np.square(example_fisher_diag)
# Normalize by batch_size (not num_locations).
return np.diag(fisher_diag.flatten()) / batch_size
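For readers checking the math, the loop in buildDiagonalFisherApproximation can be restated in vectorized NumPy. This is an editorial sketch under the same kernel_size == 1 assumption; the function name is illustrative and not part of the test file.

import numpy as np

def diagonal_fisher_approx_vectorized(inputs, output_grads):
  # inputs: [batch_size, height, width, input_channels]
  # output_grads: [batch_size, height, width, output_channels]
  batch_size = inputs.shape[0]
  # Flatten spatial locations: [batch, locations, channels].
  a = inputs.reshape(batch_size, -1, inputs.shape[-1])
  g = output_grads.reshape(batch_size, -1, output_grads.shape[-1])
  # Per-example d loss / d w: sum over locations of outer(a_loc, g_loc),
  # shape [batch, input_channels, output_channels].
  per_example = np.einsum('bli,blo->bio', a, g)
  # Square elementwise, average over examples, and place the result on the
  # diagonal of a [num_params, num_params] matrix, as the loop above does.
  fisher_diag = np.mean(np.square(per_example), axis=0)
  return np.diag(fisher_diag.flatten())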
def testMultiply(self):
result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
[self.output_grads])
# Construct Fisher-vector product.
expected_result = self.fisherApprox().dot(self.w.flatten())
expected_result = expected_result.reshape([
self.kernel_size, self.kernel_size, self.input_channels,
self.output_channels
])
self.assertAllClose(expected_result, result)
def testMultiplyInverse(self):
_, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
[self.output_grads])
# Construct inverse Fisher-vector product.
expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten())
expected_result = expected_result.reshape([
self.kernel_size, self.kernel_size, self.input_channels,
self.output_channels
])
self.assertAllClose(expected_result, result, atol=1e-3)
def testRegisterAdditionalMinibatch(self):
"""Ensure 1 big minibatch and 2 small minibatches are equivalent."""
multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps(
self.w, [self.inputs], [self.outputs], [self.output_grads])
multiply_result_small, multiply_inverse_result_small = (
self.runFisherBlockOps(self.w,
np.split(self.inputs, 2),
np.split(self.outputs, 2),
np.split(self.output_grads, 2)))
self.assertAllClose(multiply_result_big, multiply_result_small)
self.assertAllClose(multiply_inverse_result_big,
multiply_inverse_result_small)
def testMultiplyHasBias(self):
result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs],
[self.outputs], [self.output_grads])
# Clone 'b' along 'input_channels' dimension.
b_filter = np.tile(
np.reshape(self.b, [1, 1, 1, self.output_channels]),
[self.kernel_size, self.kernel_size, 1, 1])
params = np.concatenate([self.w, b_filter], axis=2)
expected_result = self.fisherApprox(True).dot(params.flatten())
# Extract 'b' from concatenated parameters.
expected_result = expected_result.reshape([
self.kernel_size, self.kernel_size, self.input_channels + 1,
self.output_channels
])
expected_result = (expected_result[:, :, 0:-1, :], np.reshape(
expected_result[:, :, -1, :], [self.output_channels]))
self.assertEqual(len(result), 2)
self.assertAllClose(expected_result[0], result[0])
self.assertAllClose(expected_result[1], result[1])
def runFisherBlockOps(self, params, inputs, outputs, output_grads):
"""Run Ops guaranteed by FisherBlock interface.
Args:
params: Tensor or 2-tuple of Tensors. Represents weights or weights and
bias of this layer.
inputs: list of Tensors of shape [batch_size, height, width,
input_channels]. Inputs to the convolutional layer.
outputs: list of Tensors of shape [batch_size, height, width,
output_channels]. Preactivations produced by the layer.
output_grads: list of Tensors of shape [batch_size, height, width,
output_channels]. Gradient of loss with respect to 'outputs'.
Returns:
multiply_result: Result of FisherBlock.multiply(params)
multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
"""
with ops.Graph().as_default(), self.test_session() as sess:
inputs = as_tensors(inputs)
outputs = as_tensors(outputs)
output_grads = as_tensors(output_grads)
params = as_tensors(params)
block = fb.ConvDiagonalFB(
lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME')
for (i, o) in zip(inputs, outputs):
block.register_additional_minibatch(i, o)
block.instantiate_factors((output_grads,), damping=0.0)
sess.run(tf_variables.global_variables_initializer())
sess.run(block._factor.make_covariance_update_op(0.0))
multiply_result = sess.run(block.multiply(params))
multiply_inverse_result = sess.run(block.multiply_inverse(params))
return multiply_result, multiply_inverse_result
class ConvKFCBasicFBTest(test.TestCase):
def _testConvKFCBasicFBInitParams(self, params):
@@ -583,5 +759,11 @@ class ConvKFCBasicFBTest(test.TestCase):
self.assertAllClose(output_flat, explicit)
def as_tensors(tensor_or_tuple):
"""Converts a potentially nested tuple of np.array to Tensors."""
if isinstance(tensor_or_tuple, (tuple, list)):
return tuple(as_tensors(t) for t in tensor_or_tuple)
return ops.convert_to_tensor(tensor_or_tuple)
if __name__ == '__main__':
test.main()


@@ -89,6 +89,10 @@ class LayerCollectionTest(test.TestCase):
lc.register_conv2d(
array_ops.constant(4), [1, 1, 1, 1], 'SAME',
array_ops.ones((1, 1, 1, 1)), array_ops.constant(3))
lc.register_conv2d(
array_ops.constant(4), [1, 1, 1, 1], 'SAME',
array_ops.ones((1, 1, 1, 1)), array_ops.constant(3),
approx=layer_collection.APPROX_DIAGONAL_NAME)
lc.register_generic(
array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME)
lc.register_generic(
@@ -96,7 +100,7 @@ class LayerCollectionTest(test.TestCase):
16,
approx=layer_collection.APPROX_DIAGONAL_NAME)
self.assertEqual(5, len(lc.get_blocks()))
self.assertEqual(6, len(lc.get_blocks()))
def testRegisterBlocksMultipleRegistrations(self):
with ops.Graph().as_default():


@@ -227,7 +227,7 @@ class FullyConnectedDiagonalFB(FisherBlock):
'w'. For an example 'x' that produces layer inputs 'a' and output
preactivations 's',
v(x, y, w) = vec( x (d loss / d s)^T )
v(x, y, w) = vec( a (d loss / d s)^T )
This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
to the layer's parameters 'w'.
@@ -309,13 +309,29 @@ class FullyConnectedDiagonalFB(FisherBlock):
class ConvDiagonalFB(FisherBlock):
"""FisherBlock for convolutional layers using a diagonal approx.
Unlike NaiveDiagonalFB this uses the low-variance "sum of squares" estimator.
Estimates the Fisher Information matrix's diagonal entries for a convolutional
layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares"
estimator.
Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
into it. We are interested in Fisher(params)[i, i]. This is,
Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
= E[ v(x, y, params)[i] ^ 2 ]
Consider a convolutional layer in this model with (unshared) filter matrix
'w'. For an example image 'x' that produces layer inputs 'a' and output
preactivations 's',
v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )
where 'loc' is a single (x, y) location in an image.
This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
to the layer's parameters 'w'.
"""
# TODO(jamesmartens): add unit tests for this class
def __init__(self, layer_collection, params, inputs, outputs, strides,
padding):
def __init__(self, layer_collection, params, strides, padding):
"""Creates a ConvDiagonalFB block.
Args:
@@ -325,37 +341,39 @@ class ConvDiagonalFB(FisherBlock):
kernel alone, a Tensor of shape [kernel_height, kernel_width,
in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
containing the previous and a Tensor of shape [out_channels].
inputs: A Tensor of shape [batch_size, height, width, in_channels].
Input activations to this layer.
outputs: A Tensor of shape [batch_size, height, width, out_channels].
Output pre-activations from this layer.
strides: The stride size in this layer (1-D Tensor of length 4).
padding: The padding in this layer (1-D of Tensor length 4).
padding: The padding in this layer (e.g. "SAME").
"""
self._inputs = inputs
self._outputs = outputs
self._strides = strides
self._inputs = []
self._outputs = []
self._strides = tuple(strides) if isinstance(strides, list) else strides
self._padding = padding
self._has_bias = isinstance(params, (tuple, list))
fltr = params[0] if self._has_bias else params
self._filter_shape = tuple(fltr.shape.as_list())
input_shape = tuple(inputs.shape.as_list())
self._num_locations = (
input_shape[1] * input_shape[2] // (strides[1] * strides[2]))
super(ConvDiagonalFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
# Concatenate inputs, grads_list into single Tensors.
inputs = _concat_along_batch_dim(self._inputs)
grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
# Infer number of locations upon which convolution is applied.
inputs_shape = tuple(inputs.shape.as_list())
self._num_locations = (
inputs_shape[1] * inputs_shape[2] //
(self._strides[1] * self._strides[2]))
if NORMALIZE_DAMPING_POWER:
damping /= self._num_locations**NORMALIZE_DAMPING_POWER
self._damping = damping
self._factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvDiagonalFactor,
(self._inputs, grads_list, self._filter_shape, self._strides,
self._padding, self._has_bias))
(inputs, grads_list, self._filter_shape, self._strides, self._padding,
self._has_bias))
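# Worked example (editorial note, not part of the commit): with the test
# configuration above (height=8, width=4, strides=[1, 1, 1, 1]), this gives
# num_locations = 8 * 4 // (1 * 1) = 32, so the damping handed to the factor
# becomes damping / 32**NORMALIZE_DAMPING_POWER.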
def multiply_inverse(self, vector):
reshaped_vect = utils.layer_params_to_mat2d(vector)
@@ -370,6 +388,18 @@ class ConvDiagonalFB(FisherBlock):
def tensors_to_compute_grads(self):
return self._outputs
def register_additional_minibatch(self, inputs, outputs):
"""Registers an additional minibatch to the FisherBlock.
Args:
inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to
the convolution.
outputs: Tensor of shape [batch_size, height, width, output_size]. Layer
preactivations.
"""
self._inputs.append(inputs)
self._outputs.append(outputs)
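The multi-tower flow that register_additional_minibatch enables is the one exercised by runFisherBlockOps in the test above. A minimal sketch, assuming the test file's imports (fb for fisher_blocks, lc for layer_collection) and placeholder per-tower lists that are not part of the API:

block = fb.ConvDiagonalFB(
    lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME')
# One registration per tower/minibatch; the registered inputs and outputs
# accumulate in the block and are concatenated along the batch dimension
# when instantiate_factors runs.
for tower_inputs, tower_outputs in zip(inputs_per_tower, outputs_per_tower):
  block.register_additional_minibatch(tower_inputs, tower_outputs)
block.instantiate_factors((output_grads_per_tower,), damping=0.0)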
class KroneckerProductFB(FisherBlock):
"""A base class for FisherBlocks with separate input and output factors.


@@ -273,9 +273,9 @@ class LayerCollection(object):
fb.ConvKFCBasicFB(self, params, inputs, outputs,
strides, padding))
elif approx == APPROX_DIAGONAL_NAME:
self.register_block(params,
fb.ConvDiagonalFB(self, params, inputs, outputs,
strides, padding))
block = fb.ConvDiagonalFB(self, params, strides, padding)
block.register_additional_minibatch(inputs, outputs)
self.register_block(params, block)
def register_generic(self, params, batch_size, approx=APPROX_DIAGONAL_NAME):
params = params if isinstance(params, (tuple, list)) else (params,)
@@ -379,6 +379,27 @@ class LayerCollection(object):
self._loss_dict[name] = loss
def make_or_get_factor(self, cls, args):
"""Insert 'cls(args)' into 'self.fisher_factors' if not already present.
Wraps constructor in 'tf.variable_scope()' to ensure variables constructed
in 'cls.__init__' are placed under this LayerCollection's scope.
Args:
cls: Class that implements FisherFactor.
args: Tuple of arguments to pass into 'cls's constructor. Must be
hashable.
Returns:
Instance of 'cls' found in self.fisher_factors.
"""
try:
hash(args)
except TypeError:
raise TypeError((
"Unable to use (cls, args) = ({}, {}) as a key in "
"LayerCollection.fisher_factors. The pair cannot be hashed."
).format(cls, args))
with variable_scope.variable_scope(self._var_scope):
return utils.setdefault(self.fisher_factors, (cls, args),
lambda: cls(*args))
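The memoization that make_or_get_factor relies on reduces to a dictionary keyed on the hashable (cls, args) pair. A minimal sketch of the pattern (the actual utils.setdefault may differ in detail):

def setdefault(d, key, thunk):
  # Return d[key] if the key is already cached; otherwise call thunk() once,
  # store its result under key, and return it. This is why (cls, args) must
  # be hashable: the pair is used directly as a dictionary key.
  if key not in d:
    d[key] = thunk()
  return d[key]

# Mirroring the call site above:
#   factor = setdefault(fisher_factors_dict, (cls, args), lambda: cls(*args))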