K-FAC: Multi-tower support for ConvDiagonalFB.
PiperOrigin-RevId: 173105412
parent fd8d517b97
commit eea089bdb6
tensorflow/contrib/kfac/python
@@ -328,17 +328,11 @@ class FullyConnectedDiagonalFB(test.TestCase):
       multiply_result: Result of FisherBlock.multiply(params)
       multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
     """
-
-    def _as_tensors(tensor_or_tuple):
-      if isinstance(tensor_or_tuple, (tuple, list)):
-        return tuple(ops.convert_to_tensor(t) for t in tensor_or_tuple)
-      return ops.convert_to_tensor(tensor_or_tuple)
-
     with ops.Graph().as_default(), self.test_session() as sess:
-      inputs = [_as_tensors(i) for i in inputs]
-      outputs = [_as_tensors(o) for o in outputs]
-      output_grads = [_as_tensors(og) for og in output_grads]
-      params = _as_tensors(params)
+      inputs = as_tensors(inputs)
+      outputs = as_tensors(outputs)
+      output_grads = as_tensors(output_grads)
+      params = as_tensors(params)

       block = fb.FullyConnectedDiagonalFB(
           lc.LayerCollection(), has_bias=isinstance(params, (tuple, list)))
@@ -464,6 +458,188 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
     self.assertAllClose(output_flat, explicit)


+class ConvDiagonalFBTest(test.TestCase):
+
+  def setUp(self):
+    super(ConvDiagonalFBTest, self).setUp()
+
+    self.batch_size = 2
+    self.height = 8
+    self.width = 4
+    self.input_channels = 6
+    self.output_channels = 3
+    self.kernel_size = 1
+
+    self.inputs = np.random.randn(self.batch_size, self.height, self.width,
+                                  self.input_channels).astype(np.float32)
+    self.outputs = np.zeros(
+        [self.batch_size, self.height, self.width,
+         self.output_channels]).astype(np.float32)
+    self.output_grads = np.random.randn(
+        self.batch_size, self.height, self.width, self.output_channels).astype(
+            np.float32)
+    self.w = np.random.randn(self.kernel_size, self.kernel_size,
+                             self.input_channels, self.output_channels).astype(
+                                 np.float32)
+    self.b = np.random.randn(self.output_channels).astype(np.float32)
+
+  def fisherApprox(self, has_bias=False):
+    """Fisher approximation using default inputs."""
+    if has_bias:
+      inputs = np.concatenate(
+          [self.inputs,
+           np.ones([self.batch_size, self.height, self.width, 1])],
+          axis=-1)
+    else:
+      inputs = self.inputs
+    return self.buildDiagonalFisherApproximation(inputs, self.output_grads,
+                                                 self.kernel_size)
+
+  def buildDiagonalFisherApproximation(self, inputs, output_grads, kernel_size):
+    r"""Builds explicit diagonal Fisher approximation.
+
+    Fisher's diagonal is (d loss / d w)'s elements squared for
+      d/dw = E[\sum_{loc} outer(input_{loc}, output_grad_{loc})]
+
+    where the expectation is taken over examples and the sum over (x, y)
+    locations upon which the convolution is applied.
+
+    Args:
+      inputs: np.array of shape [batch_size, height, width, input_channels].
+      output_grads: np.array of shape [batch_size, height, width,
+        output_channels].
+      kernel_size: int. height and width of kernel.
+
+    Returns:
+      Diagonal np.array of shape [num_params, num_params] for num_params =
+      kernel_size^2 * input_channels * output_channels.
+    """
+    batch_size, height, width, input_channels = inputs.shape
+    assert output_grads.shape[0] == batch_size
+    assert output_grads.shape[1] == height
+    assert output_grads.shape[2] == width
+    output_channels = output_grads.shape[3]
+
+    # If kernel_size == 1, then we don't need to worry about capturing context
+    # around the pixel upon which a convolution is applied. This makes testing
+    # easier.
+    assert kernel_size == 1, "kernel_size != 1 isn't supported."
+    num_locations = height * width
+    inputs = np.reshape(inputs, [batch_size, num_locations, input_channels])
+    output_grads = np.reshape(output_grads,
+                              [batch_size, num_locations, output_channels])
+
+    fisher_diag = np.zeros((input_channels, output_channels))
+    for i in range(batch_size):
+      # Each example's approximation is a square(sum-of-outer-products).
+      example_fisher_diag = np.zeros((input_channels, output_channels))
+      for j in range(num_locations):
+        example_fisher_diag += np.outer(inputs[i, j], output_grads[i, j])
+      fisher_diag += np.square(example_fisher_diag)
+
+    # Normalize by batch_size (not num_locations).
+    return np.diag(fisher_diag.flatten()) / batch_size
+
+  def testMultiply(self):
+    result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
+                                       [self.output_grads])
+
+    # Construct Fisher-vector product.
+    expected_result = self.fisherApprox().dot(self.w.flatten())
+    expected_result = expected_result.reshape([
+        self.kernel_size, self.kernel_size, self.input_channels,
+        self.output_channels
+    ])
+
+    self.assertAllClose(expected_result, result)
+
+  def testMultiplyInverse(self):
+    _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
+                                       [self.output_grads])
+
+    # Construct inverse Fisher-vector product.
+    expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten())
+    expected_result = expected_result.reshape([
+        self.kernel_size, self.kernel_size, self.input_channels,
+        self.output_channels
+    ])
+
+    self.assertAllClose(expected_result, result, atol=1e-3)
+
+  def testRegisterAdditionalMinibatch(self):
+    """Ensure 1 big minibatch and 2 small minibatches are equivalent."""
+    multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps(
+        self.w, [self.inputs], [self.outputs], [self.output_grads])
+    multiply_result_small, multiply_inverse_result_small = (
+        self.runFisherBlockOps(self.w,
+                               np.split(self.inputs, 2),
+                               np.split(self.outputs, 2),
+                               np.split(self.output_grads, 2)))
+
+    self.assertAllClose(multiply_result_big, multiply_result_small)
+    self.assertAllClose(multiply_inverse_result_big,
+                        multiply_inverse_result_small)
+
+  def testMultiplyHasBias(self):
+    result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs],
+                                       [self.outputs], [self.output_grads])
+    # Clone 'b' along 'input_channels' dimension.
+    b_filter = np.tile(
+        np.reshape(self.b, [1, 1, 1, self.output_channels]),
+        [self.kernel_size, self.kernel_size, 1, 1])
+    params = np.concatenate([self.w, b_filter], axis=2)
+    expected_result = self.fisherApprox(True).dot(params.flatten())
+
+    # Extract 'b' from concatenated parameters.
+    expected_result = expected_result.reshape([
+        self.kernel_size, self.kernel_size, self.input_channels + 1,
+        self.output_channels
+    ])
+    expected_result = (expected_result[:, :, 0:-1, :], np.reshape(
+        expected_result[:, :, -1, :], [self.output_channels]))
+
+    self.assertEqual(len(result), 2)
+    self.assertAllClose(expected_result[0], result[0])
+    self.assertAllClose(expected_result[1], result[1])
+
+  def runFisherBlockOps(self, params, inputs, outputs, output_grads):
+    """Run Ops guaranteed by FisherBlock interface.
+
+    Args:
+      params: Tensor or 2-tuple of Tensors. Represents weights or weights and
+        bias of this layer.
+      inputs: list of Tensors of shape [batch_size, input_size]. Inputs to
+        layer.
+      outputs: list of Tensors of shape [batch_size, output_size].
+        Preactivations produced by layer.
+      output_grads: list of Tensors of shape [batch_size, output_size].
+        Gradient of loss with respect to 'outputs'.
+
+    Returns:
+      multiply_result: Result of FisherBlock.multiply(params)
+      multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
+    """
+    with ops.Graph().as_default(), self.test_session() as sess:
+      inputs = as_tensors(inputs)
+      outputs = as_tensors(outputs)
+      output_grads = as_tensors(output_grads)
+      params = as_tensors(params)
+
+      block = fb.ConvDiagonalFB(
+          lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME')
+      for (i, o) in zip(inputs, outputs):
+        block.register_additional_minibatch(i, o)
+
+      block.instantiate_factors((output_grads,), damping=0.0)
+
+      sess.run(tf_variables.global_variables_initializer())
+      sess.run(block._factor.make_covariance_update_op(0.0))
+      multiply_result = sess.run(block.multiply(params))
+      multiply_inverse_result = sess.run(block.multiply_inverse(params))
+
+    return multiply_result, multiply_inverse_result
+
+
 class ConvKFCBasicFBTest(test.TestCase):

   def _testConvKFCBasicFBInitParams(self, params):
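The quantity that buildDiagonalFisherApproximation constructs above can be
restated as an equation (a rewording of the docstring's formula for the
kernel_size = 1 case, with a_{n,loc} the input and g_{n,loc} the loss gradient
at location loc of example n):

    \hat{F}_{diag}
        = \frac{1}{N} \sum_{n=1}^{N}
          \operatorname{vec}\Big( \sum_{loc} a_{n,loc} \, g_{n,loc}^{\top} \Big)^{\circ 2}

Each example's outer products are first summed over spatial locations, the
result is squared elementwise, and only then is the average over the N examples
taken, matching the nested loops in the helper and the final division by
batch_size rather than num_locations.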
@@ -583,5 +759,11 @@ class ConvKFCBasicFBTest(test.TestCase):
     self.assertAllClose(output_flat, explicit)


+def as_tensors(tensor_or_tuple):
+  """Converts a potentially nested tuple of np.array to Tensors."""
+  if isinstance(tensor_or_tuple, (tuple, list)):
+    return tuple(as_tensors(t) for t in tensor_or_tuple)
+  return ops.convert_to_tensor(tensor_or_tuple)
+
 if __name__ == '__main__':
   test.main()
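As a minimal illustration of why as_tensors recurses (hypothetical values, not
part of the commit): the multi-tower tests above pass lists with one array per
tower, and params may itself be a (weights, bias) tuple.

    w = np.ones([1, 1, 6, 3], dtype=np.float32)        # conv kernel
    b = np.zeros([3], dtype=np.float32)                # bias
    towers = [np.ones([2, 8, 4, 6], dtype=np.float32),
              np.ones([2, 8, 4, 6], dtype=np.float32)]

    params = as_tensors((w, b))   # -> (Tensor, Tensor)
    inputs = as_tensors(towers)   # -> a tuple with one Tensor per tower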
@@ -89,6 +89,10 @@ class LayerCollectionTest(test.TestCase):
       lc.register_conv2d(
           array_ops.constant(4), [1, 1, 1, 1], 'SAME',
           array_ops.ones((1, 1, 1, 1)), array_ops.constant(3))
+      lc.register_conv2d(
+          array_ops.constant(4), [1, 1, 1, 1], 'SAME',
+          array_ops.ones((1, 1, 1, 1)), array_ops.constant(3),
+          approx=layer_collection.APPROX_DIAGONAL_NAME)
       lc.register_generic(
           array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME)
       lc.register_generic(
@@ -96,7 +100,7 @@ class LayerCollectionTest(test.TestCase):
           16,
           approx=layer_collection.APPROX_DIAGONAL_NAME)

-      self.assertEqual(5, len(lc.get_blocks()))
+      self.assertEqual(6, len(lc.get_blocks()))

   def testRegisterBlocksMultipleRegistrations(self):
     with ops.Graph().as_default():
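In user code, the registration pattern exercised by the test above would look
roughly like the following sketch (the weight and activation tensors are
placeholders; only the call shape is taken from the test):

    lc = layer_collection.LayerCollection()
    # Default Kronecker-factored approximation for one conv layer ...
    lc.register_conv2d(conv1_weights, [1, 1, 1, 1], 'SAME',
                       conv1_inputs, conv1_preactivations)
    # ... and the diagonal approximation, which now goes through the
    # multi-tower ConvDiagonalFB, for another.
    lc.register_conv2d(conv2_weights, [1, 1, 1, 1], 'SAME',
                       conv2_inputs, conv2_preactivations,
                       approx=layer_collection.APPROX_DIAGONAL_NAME)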
@@ -227,7 +227,7 @@ class FullyConnectedDiagonalFB(FisherBlock):
   'w'. For an example 'x' that produces layer inputs 'a' and output
   preactivations 's',

-    v(x, y, w) = vec( x (d loss / d s)^T )
+    v(x, y, w) = vec( a (d loss / d s)^T )

   This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
   to the layer's parameters 'w'.
@@ -309,13 +309,29 @@ class FullyConnectedDiagonalFB(FisherBlock):
 class ConvDiagonalFB(FisherBlock):
   """FisherBlock for convolutional layers using a diagonal approx.

-  Unlike NaiveDiagonalFB this uses the low-variance "sum of squares" estimator.
+  Estimates the Fisher Information matrix's diagonal entries for a convolutional
+  layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares"
+  estimator.
+
+  Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
+  into it. We are interested in Fisher(params)[i, i]. This is,
+
+    Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
+                         = E[ v(x, y, params)[i] ^ 2 ]
+
+  Consider a convolutional layer in this model with (unshared) filter matrix
+  'w'. For an example image 'x' that produces layer inputs 'a' and output
+  preactivations 's',
+
+    v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )
+
+  where 'loc' is a single (x, y) location in an image.

   This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
   to the layer's parameters 'w'.
   """

   # TODO(jamesmartens): add unit tests for this class

-  def __init__(self, layer_collection, params, inputs, outputs, strides,
-               padding):
+  def __init__(self, layer_collection, params, strides, padding):
     """Creates a ConvDiagonalFB block.

     Args:
@@ -325,37 +341,39 @@
         kernel alone, a Tensor of shape [kernel_height, kernel_width,
         in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
         containing the previous and a Tensor of shape [out_channels].
-      inputs: A Tensor of shape [batch_size, height, width, in_channels].
-        Input activations to this layer.
-      outputs: A Tensor of shape [batch_size, height, width, out_channels].
-        Output pre-activations from this layer.
       strides: The stride size in this layer (1-D Tensor of length 4).
-      padding: The padding in this layer (1-D of Tensor length 4).
+      padding: The padding in this layer (e.g. "SAME").
     """
-    self._inputs = inputs
-    self._outputs = outputs
-    self._strides = strides
+    self._inputs = []
+    self._outputs = []
+    self._strides = tuple(strides) if isinstance(strides, list) else strides
     self._padding = padding
     self._has_bias = isinstance(params, (tuple, list))

     fltr = params[0] if self._has_bias else params
     self._filter_shape = tuple(fltr.shape.as_list())

-    input_shape = tuple(inputs.shape.as_list())
-    self._num_locations = (
-        input_shape[1] * input_shape[2] // (strides[1] * strides[2]))
-
     super(ConvDiagonalFB, self).__init__(layer_collection)

   def instantiate_factors(self, grads_list, damping):
+    # Concatenate inputs, grads_list into single Tensors.
+    inputs = _concat_along_batch_dim(self._inputs)
+    grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
+
+    # Infer number of locations upon which convolution is applied.
+    inputs_shape = tuple(inputs.shape.as_list())
+    self._num_locations = (
+        inputs_shape[1] * inputs_shape[2] //
+        (self._strides[1] * self._strides[2]))
+
     if NORMALIZE_DAMPING_POWER:
       damping /= self._num_locations**NORMALIZE_DAMPING_POWER
     self._damping = damping

     self._factor = self._layer_collection.make_or_get_factor(
         fisher_factors.ConvDiagonalFactor,
-        (self._inputs, grads_list, self._filter_shape, self._strides,
-         self._padding, self._has_bias))
+        (inputs, grads_list, self._filter_shape, self._strides, self._padding,
+         self._has_bias))

   def multiply_inverse(self, vector):
     reshaped_vect = utils.layer_params_to_mat2d(vector)
@@ -370,6 +388,18 @@
   def tensors_to_compute_grads(self):
     return self._outputs

+  def register_additional_minibatch(self, inputs, outputs):
+    """Registers an additional minibatch to the FisherBlock.
+
+    Args:
+      inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to
+        the convolution.
+      outputs: Tensor of shape [batch_size, height, width, output_size]. Layer
+        preactivations.
+    """
+    self._inputs.append(inputs)
+    self._outputs.append(outputs)
+

 class KroneckerProductFB(FisherBlock):
   """A base class for FisherBlocks with separate input and output factors.
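Taken together with the constructor change above, a block that sees two towers
would be driven roughly like this (tensor names are placeholders; the call
sequence mirrors runFisherBlockOps in the test file):

    block = fb.ConvDiagonalFB(
        lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME')

    # One call per tower; the per-tower inputs/outputs are only concatenated
    # along the batch dimension later, inside instantiate_factors.
    block.register_additional_minibatch(tower0_inputs, tower0_preactivations)
    block.register_additional_minibatch(tower1_inputs, tower1_preactivations)

    # grads_list holds one tuple of per-tower output gradients.
    block.instantiate_factors((output_grads,), damping=0.0)

With the test's 8x4 inputs and unit strides, the _num_locations computed in
instantiate_factors works out to 8 * 4 // (1 * 1) = 32.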
@@ -273,9 +273,9 @@ class LayerCollection(object):
                           fb.ConvKFCBasicFB(self, params, inputs, outputs,
                                             strides, padding))
     elif approx == APPROX_DIAGONAL_NAME:
-      self.register_block(params,
-                          fb.ConvDiagonalFB(self, params, inputs, outputs,
-                                            strides, padding))
+      block = fb.ConvDiagonalFB(self, params, strides, padding)
+      block.register_additional_minibatch(inputs, outputs)
+      self.register_block(params, block)

   def register_generic(self, params, batch_size, approx=APPROX_DIAGONAL_NAME):
     params = params if isinstance(params, (tuple, list)) else (params,)
@@ -379,6 +379,27 @@ class LayerCollection(object):
     self._loss_dict[name] = loss

   def make_or_get_factor(self, cls, args):
+    """Insert 'cls(args)' into 'self.fisher_factors' if not already present.
+
+    Wraps constructor in 'tf.variable_scope()' to ensure variables constructed
+    in 'cls.__init__' are placed under this LayerCollection's scope.
+
+    Args:
+      cls: Class that implements FisherFactor.
+      args: Tuple of arguments to pass into 'cls's constructor. Must be
+        hashable.
+
+    Returns:
+      Instance of 'cls' found in self.fisher_factors.
+    """
+    try:
+      hash(args)
+    except TypeError:
+      raise TypeError((
+          "Unable to use (cls, args) = ({}, {}) as a key in "
+          "LayerCollection.fisher_factors. The pair cannot be hashed."
+      ).format(cls, args))
+
     with variable_scope.variable_scope(self._var_scope):
       return utils.setdefault(self.fisher_factors, (cls, args),
                               lambda: cls(*args))
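The memoization this provides can be illustrated with a short sketch
(hypothetical variable names; the point is only that a repeated (cls, args)
key returns the cached FisherFactor instead of constructing a new one):

    collection = LayerCollection()
    key_args = (inputs, grads_list, filter_shape, strides, padding, has_bias)

    factor_a = collection.make_or_get_factor(
        fisher_factors.ConvDiagonalFactor, key_args)
    factor_b = collection.make_or_get_factor(
        fisher_factors.ConvDiagonalFactor, key_args)

    assert factor_a is factor_b  # second call hits the cache; no new variables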