diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
index 9b13756e62f..80855da2e92 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
@@ -328,17 +328,11 @@ class FullyConnectedDiagonalFB(test.TestCase):
       multiply_result: Result of FisherBlock.multiply(params)
       multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
     """
-
-    def _as_tensors(tensor_or_tuple):
-      if isinstance(tensor_or_tuple, (tuple, list)):
-        return tuple(ops.convert_to_tensor(t) for t in tensor_or_tuple)
-      return ops.convert_to_tensor(tensor_or_tuple)
-
     with ops.Graph().as_default(), self.test_session() as sess:
-      inputs = [_as_tensors(i) for i in inputs]
-      outputs = [_as_tensors(o) for o in outputs]
-      output_grads = [_as_tensors(og) for og in output_grads]
-      params = _as_tensors(params)
+      inputs = as_tensors(inputs)
+      outputs = as_tensors(outputs)
+      output_grads = as_tensors(output_grads)
+      params = as_tensors(params)
 
       block = fb.FullyConnectedDiagonalFB(
           lc.LayerCollection(), has_bias=isinstance(params, (tuple, list)))
@@ -464,6 +458,188 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
 
     self.assertAllClose(output_flat, explicit)
 
+
+class ConvDiagonalFBTest(test.TestCase):
+
+  def setUp(self):
+    super(ConvDiagonalFBTest, self).setUp()
+
+    self.batch_size = 2
+    self.height = 8
+    self.width = 4
+    self.input_channels = 6
+    self.output_channels = 3
+    self.kernel_size = 1
+
+    self.inputs = np.random.randn(self.batch_size, self.height, self.width,
+                                  self.input_channels).astype(np.float32)
+    self.outputs = np.zeros(
+        [self.batch_size, self.height, self.width,
+         self.output_channels]).astype(np.float32)
+    self.output_grads = np.random.randn(
+        self.batch_size, self.height, self.width, self.output_channels).astype(
+            np.float32)
+    self.w = np.random.randn(self.kernel_size, self.kernel_size,
+                             self.input_channels, self.output_channels).astype(
+                                 np.float32)
+    self.b = np.random.randn(self.output_channels).astype(np.float32)
+
+  def fisherApprox(self, has_bias=False):
+    """Fisher approximation using default inputs."""
+    if has_bias:
+      inputs = np.concatenate(
+          [self.inputs,
+           np.ones([self.batch_size, self.height, self.width, 1])],
+          axis=-1)
+    else:
+      inputs = self.inputs
+    return self.buildDiagonalFisherApproximation(inputs, self.output_grads,
+                                                 self.kernel_size)
+
+  def buildDiagonalFisherApproximation(self, inputs, output_grads, kernel_size):
+    r"""Builds explicit diagonal Fisher approximation.
+
+    Fisher's diagonal is (d loss / d w)'s elements squared for
+
+      d/dw = E[\sum_{loc} outer(input_{loc}, output_grad_{loc})]
+
+    where the expectation is taken over examples and the sum over (x, y)
+    locations upon which the convolution is applied.
+
+    Args:
+      inputs: np.array of shape [batch_size, height, width, input_channels].
+      output_grads: np.array of shape [batch_size, height, width,
+        output_channels].
+      kernel_size: int. height and width of kernel.
+
+    Returns:
+      Diagonal np.array of shape [num_params, num_params] for num_params =
+      kernel_size^2 * input_channels * output_channels.
+ """ + batch_size, height, width, input_channels = inputs.shape + assert output_grads.shape[0] == batch_size + assert output_grads.shape[1] == height + assert output_grads.shape[2] == width + output_channels = output_grads.shape[3] + + # If kernel_size == 1, then we don't need to worry about capturing context + # around the pixel upon which a convolution is applied. This makes testing + # easier. + assert kernel_size == 1, "kernel_size != 1 isn't supported." + num_locations = height * width + inputs = np.reshape(inputs, [batch_size, num_locations, input_channels]) + output_grads = np.reshape(output_grads, + [batch_size, num_locations, output_channels]) + + fisher_diag = np.zeros((input_channels, output_channels)) + for i in range(batch_size): + # Each example's approximation is a square(sum-of-outer-products). + example_fisher_diag = np.zeros((input_channels, output_channels)) + for j in range(num_locations): + example_fisher_diag += np.outer(inputs[i, j], output_grads[i, j]) + fisher_diag += np.square(example_fisher_diag) + + # Normalize by batch_size (not num_locations). + return np.diag(fisher_diag.flatten()) / batch_size + + def testMultiply(self): + result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], + [self.output_grads]) + + # Construct Fisher-vector product. + expected_result = self.fisherApprox().dot(self.w.flatten()) + expected_result = expected_result.reshape([ + self.kernel_size, self.kernel_size, self.input_channels, + self.output_channels + ]) + + self.assertAllClose(expected_result, result) + + def testMultiplyInverse(self): + _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], + [self.output_grads]) + + # Construct inverse Fisher-vector product. + expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten()) + expected_result = expected_result.reshape([ + self.kernel_size, self.kernel_size, self.input_channels, + self.output_channels + ]) + + self.assertAllClose(expected_result, result, atol=1e-3) + + def testRegisterAdditionalMinibatch(self): + """Ensure 1 big minibatch and 2 small minibatches are equivalent.""" + multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( + self.w, [self.inputs], [self.outputs], [self.output_grads]) + multiply_result_small, multiply_inverse_result_small = ( + self.runFisherBlockOps(self.w, + np.split(self.inputs, 2), + np.split(self.outputs, 2), + np.split(self.output_grads, 2))) + + self.assertAllClose(multiply_result_big, multiply_result_small) + self.assertAllClose(multiply_inverse_result_big, + multiply_inverse_result_small) + + def testMultiplyHasBias(self): + result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs], + [self.outputs], [self.output_grads]) + # Clone 'b' along 'input_channels' dimension. + b_filter = np.tile( + np.reshape(self.b, [1, 1, 1, self.output_channels]), + [self.kernel_size, self.kernel_size, 1, 1]) + params = np.concatenate([self.w, b_filter], axis=2) + expected_result = self.fisherApprox(True).dot(params.flatten()) + + # Extract 'b' from concatenated parameters. 
+    expected_result = expected_result.reshape([
+        self.kernel_size, self.kernel_size, self.input_channels + 1,
+        self.output_channels
+    ])
+    expected_result = (expected_result[:, :, 0:-1, :], np.reshape(
+        expected_result[:, :, -1, :], [self.output_channels]))
+
+    self.assertEqual(len(result), 2)
+    self.assertAllClose(expected_result[0], result[0])
+    self.assertAllClose(expected_result[1], result[1])
+
+  def runFisherBlockOps(self, params, inputs, outputs, output_grads):
+    """Run Ops guaranteed by FisherBlock interface.
+
+    Args:
+      params: Tensor or 2-tuple of Tensors. Represents weights or weights and
+        bias of this layer.
+      inputs: list of Tensors of shape [batch_size, height, width,
+        input_channels]. Inputs to layer.
+      outputs: list of Tensors of shape [batch_size, height, width,
+        output_channels]. Preactivations produced by layer.
+      output_grads: list of Tensors of shape [batch_size, height, width,
+        output_channels]. Gradient of loss with respect to 'outputs'.
+
+    Returns:
+      multiply_result: Result of FisherBlock.multiply(params)
+      multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
+    """
+    with ops.Graph().as_default(), self.test_session() as sess:
+      inputs = as_tensors(inputs)
+      outputs = as_tensors(outputs)
+      output_grads = as_tensors(output_grads)
+      params = as_tensors(params)
+
+      block = fb.ConvDiagonalFB(
+          lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME')
+      for (i, o) in zip(inputs, outputs):
+        block.register_additional_minibatch(i, o)
+
+      block.instantiate_factors((output_grads,), damping=0.0)
+
+      sess.run(tf_variables.global_variables_initializer())
+      sess.run(block._factor.make_covariance_update_op(0.0))
+      multiply_result = sess.run(block.multiply(params))
+      multiply_inverse_result = sess.run(block.multiply_inverse(params))
+
+    return multiply_result, multiply_inverse_result
+
+
 class ConvKFCBasicFBTest(test.TestCase):
 
   def _testConvKFCBasicFBInitParams(self, params):
@@ -583,5 +759,11 @@ class ConvKFCBasicFBTest(test.TestCase):
     self.assertAllClose(output_flat, explicit)
 
 
+def as_tensors(tensor_or_tuple):
+  """Converts a potentially nested tuple of np.array to Tensors."""
+  if isinstance(tensor_or_tuple, (tuple, list)):
+    return tuple(as_tensors(t) for t in tensor_or_tuple)
+  return ops.convert_to_tensor(tensor_or_tuple)
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
index 53d40da586c..b444e871701 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
@@ -89,6 +89,10 @@ class LayerCollectionTest(test.TestCase):
       lc.register_conv2d(
           array_ops.constant(4), [1, 1, 1, 1], 'SAME',
          array_ops.ones((1, 1, 1, 1)), array_ops.constant(3))
+      lc.register_conv2d(
+          array_ops.constant(4), [1, 1, 1, 1], 'SAME',
+          array_ops.ones((1, 1, 1, 1)), array_ops.constant(3),
+          approx=layer_collection.APPROX_DIAGONAL_NAME)
       lc.register_generic(
           array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME)
       lc.register_generic(
@@ -96,7 +100,7 @@ class LayerCollectionTest(test.TestCase):
           16,
           approx=layer_collection.APPROX_DIAGONAL_NAME)
 
-      self.assertEqual(5, len(lc.get_blocks()))
+      self.assertEqual(6, len(lc.get_blocks()))
 
   def testRegisterBlocksMultipleRegistrations(self):
     with ops.Graph().as_default():
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 6cca2272d7d..5e822b5fe32 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -227,7 +227,7 @@ class FullyConnectedDiagonalFB(FisherBlock):
   'w'. For an example 'x' that produces layer inputs 'a' and output
   preactivations 's',
 
-    v(x, y, w) = vec( x (d loss / d s)^T )
+    v(x, y, w) = vec( a (d loss / d s)^T )
 
   This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
   to the layer's parameters 'w'.
@@ -309,13 +309,29 @@ class FullyConnectedDiagonalFB(FisherBlock):
 class ConvDiagonalFB(FisherBlock):
   """FisherBlock for convolutional layers using a diagonal approx.
 
-  Unlike NaiveDiagonalFB this uses the low-variance "sum of squares" estimator.
+  Estimates the Fisher Information matrix's diagonal entries for a convolutional
+  layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares"
+  estimator.
+
+  Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
+  into it. We are interested in Fisher(params)[i, i]. This is,
+
+    Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
+                         = E[ v(x, y, params)[i] ^ 2 ]
+
+  Consider a convolutional layer in this model with (unshared) filter matrix
+  'w'. For an example image 'x' that produces layer inputs 'a' and output
+  preactivations 's',
+
+    v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )
+
+  where 'loc' is a single (x, y) location in an image.
+
+  This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
+  to the layer's parameters 'w'.
   """
 
-  # TODO(jamesmartens): add units tests for this class
-
-  def __init__(self, layer_collection, params, inputs, outputs, strides,
-               padding):
+  def __init__(self, layer_collection, params, strides, padding):
     """Creates a ConvDiagonalFB block.
 
     Args:
@@ -325,37 +341,39 @@ class ConvDiagonalFB(FisherBlock):
         kernel alone, a Tensor of shape [kernel_height, kernel_width,
         in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
         containing the previous and a Tensor of shape [out_channels].
-      inputs: A Tensor of shape [batch_size, height, width, in_channels].
-        Input activations to this layer.
-      outputs: A Tensor of shape [batch_size, height, width, out_channels].
-        Output pre-activations from this layer.
       strides: The stride size in this layer (1-D Tensor of length 4).
-      padding: The padding in this layer (1-D of Tensor length 4).
+      padding: The padding in this layer (e.g. "SAME").
     """
-    self._inputs = inputs
-    self._outputs = outputs
-    self._strides = strides
+    self._inputs = []
+    self._outputs = []
+    self._strides = tuple(strides) if isinstance(strides, list) else strides
     self._padding = padding
     self._has_bias = isinstance(params, (tuple, list))
 
     fltr = params[0] if self._has_bias else params
     self._filter_shape = tuple(fltr.shape.as_list())
 
-    input_shape = tuple(inputs.shape.as_list())
-    self._num_locations = (
-        input_shape[1] * input_shape[2] // (strides[1] * strides[2]))
-
     super(ConvDiagonalFB, self).__init__(layer_collection)
 
   def instantiate_factors(self, grads_list, damping):
+    # Concatenate inputs, grads_list into single Tensors.
+    inputs = _concat_along_batch_dim(self._inputs)
+    grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
+
+    # Infer number of locations upon which convolution is applied.
+    inputs_shape = tuple(inputs.shape.as_list())
+    self._num_locations = (
+        inputs_shape[1] * inputs_shape[2] //
+        (self._strides[1] * self._strides[2]))
+
     if NORMALIZE_DAMPING_POWER:
       damping /= self._num_locations**NORMALIZE_DAMPING_POWER
     self._damping = damping
 
     self._factor = self._layer_collection.make_or_get_factor(
         fisher_factors.ConvDiagonalFactor,
-        (self._inputs, grads_list, self._filter_shape, self._strides,
-         self._padding, self._has_bias))
+        (inputs, grads_list, self._filter_shape, self._strides, self._padding,
+         self._has_bias))
 
   def multiply_inverse(self, vector):
     reshaped_vect = utils.layer_params_to_mat2d(vector)
@@ -370,6 +388,18 @@ class ConvDiagonalFB(FisherBlock):
   def tensors_to_compute_grads(self):
     return self._outputs
 
+  def register_additional_minibatch(self, inputs, outputs):
+    """Registers an additional minibatch to the FisherBlock.
+
+    Args:
+      inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to
+        the convolution.
+      outputs: Tensor of shape [batch_size, height, width, output_size]. Layer
+        preactivations.
+    """
+    self._inputs.append(inputs)
+    self._outputs.append(outputs)
+
 
 class KroneckerProductFB(FisherBlock):
   """A base class for FisherBlocks with separate input and output factors.
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index beb8ef136e3..10ef5543516 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -273,9 +273,9 @@ class LayerCollection(object):
           fb.ConvKFCBasicFB(self, params, inputs, outputs, strides, padding))
     elif approx == APPROX_DIAGONAL_NAME:
-      self.register_block(params,
-                          fb.ConvDiagonalFB(self, params, inputs, outputs,
-                                            strides, padding))
+      block = fb.ConvDiagonalFB(self, params, strides, padding)
+      block.register_additional_minibatch(inputs, outputs)
+      self.register_block(params, block)
 
   def register_generic(self, params, batch_size, approx=APPROX_DIAGONAL_NAME):
     params = params if isinstance(params, (tuple, list)) else (params,)
 
@@ -379,6 +379,27 @@ class LayerCollection(object):
     self._loss_dict[name] = loss
 
   def make_or_get_factor(self, cls, args):
+    """Insert 'cls(args)' into 'self.fisher_factors' if not already present.
+
+    Wraps constructor in 'tf.variable_scope()' to ensure variables constructed
+    in 'cls.__init__' are placed under this LayerCollection's scope.
+
+    Args:
+      cls: Class that implements FisherFactor.
+      args: Tuple of arguments to pass into 'cls's constructor. Must be
+        hashable.
+
+    Returns:
+      Instance of 'cls' found in self.fisher_factors.
+    """
+    try:
+      hash(args)
+    except TypeError:
+      raise TypeError((
+          "Unable to use (cls, args) = ({}, {}) as a key in "
+          "LayerCollection.fisher_factors. The pair cannot be hashed."
+      ).format(cls, args))
+
     with variable_scope.variable_scope(self._var_scope):
       return utils.setdefault(self.fisher_factors, (cls, args),
                               lambda: cls(*args))
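As a companion to the ConvDiagonalFB docstring above, here is a minimal NumPy sketch, not part of the patch, of the "sum of squares" diagonal estimator for a 1x1 kernel; the array names and sizes are illustrative assumptions that mirror buildDiagonalFisherApproximation in the new test.

import numpy as np

# Illustrative sizes only; a 1x1 kernel means each location contributes a
# single input vector, as in the new ConvDiagonalFBTest.
batch_size, height, width, in_channels, out_channels = 2, 8, 4, 6, 3
a = np.random.randn(batch_size, height * width, in_channels)    # layer inputs
ds = np.random.randn(batch_size, height * width, out_channels)  # d loss / d s

# v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T ), per example.
v = np.einsum('bli,blo->bio', a, ds)                   # [batch, in, out]

# Fisher diagonal estimate: E[ v(x, y, w)[i]^2 ], expectation over the batch.
fisher_diag = np.mean(np.square(v), axis=0).flatten()  # [in * out]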
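A hypothetical registration sketch, not from the patch, of how the diagonal approximation for conv layers is now reached through LayerCollection and how extra minibatches attach to the resulting block; the constant tensors and their shapes are made up for illustration.

from tensorflow.contrib.kfac.python.ops import layer_collection
from tensorflow.python.ops import array_ops

lc = layer_collection.LayerCollection()
params = array_ops.ones((1, 1, 6, 3))   # [kernel_h, kernel_w, in, out]
inputs = array_ops.ones((2, 8, 4, 6))   # [batch, height, width, in]
outputs = array_ops.ones((2, 8, 4, 3))  # layer preactivations

# With approx=APPROX_DIAGONAL_NAME, register_conv2d now builds a ConvDiagonalFB
# and registers (inputs, outputs) as its first minibatch.
lc.register_conv2d(params, [1, 1, 1, 1], 'SAME', inputs, outputs,
                   approx=layer_collection.APPROX_DIAGONAL_NAME)

# Additional minibatches can be attached before the block's factors are
# instantiated, as exercised by testRegisterAdditionalMinibatch.
block = lc.get_blocks()[0]
block.register_additional_minibatch(array_ops.ones((2, 8, 4, 6)),
                                    array_ops.ones((2, 8, 4, 3)))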
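Finally, a small illustration, again not from the patch, of why make_or_get_factor's args must be hashable; this is the reason ConvDiagonalFB now converts a list-valued strides argument to a tuple before passing it along.

# Tuples containing only hashable elements can key a dict; a tuple containing
# a list cannot, so list-valued strides would break the fisher_factors cache.
factor_cache = {}

args_with_list = ([1, 1, 1, 1], 'SAME')
args_with_tuple = ((1, 1, 1, 1), 'SAME')

try:
  factor_cache[('ConvDiagonalFactor', args_with_list)] = 'factor'
except TypeError:
  print('unhashable args cannot key LayerCollection.fisher_factors')

factor_cache[('ConvDiagonalFactor', args_with_tuple)] = 'factor'  # works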