- FisherEstimator now supports computing products with arbitrary matrix powers of the approximate Fisher (see the sketch below)
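A minimal, hedged sketch (not part of this CL's diff) of how the matrix-power support might be used. The constructor argument order, the exps argument, make_ops_and_vars(), and multiply_matpower() follow the updated estimator code and tests in the diff below; the register_fully_connected call is assumed from the existing LayerCollection API.

```python
# Hedged sketch: illustrative only, based on the updated FisherEstimator API
# shown in the diff below (exps=..., make_ops_and_vars, multiply_matpower).
import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import estimator
from tensorflow.contrib.kfac.python.ops import layer_collection

inputs = tf.random_normal([8, 4])
weights = tf.get_variable("w", shape=[4, 3])
logits = tf.matmul(inputs, weights)

lc = layer_collection.LayerCollection()
lc.register_fully_connected(weights, inputs, logits)   # assumed existing API
lc.register_categorical_predictive_distribution(logits)

# Request the inverse and the inverse square root of the approximate Fisher by
# listing the desired exponents in exps.
fisher_est = estimator.FisherEstimator(
    variables=[weights], cov_ema_decay=0.95, damping=0.1,
    layer_collection=lc, exps=(-1, -0.5))

# Build the cov/inv variables and update ops (no explicit device placement).
(cov_update_ops, cov_update_op, inv_update_ops, inv_update_op,
 cov_update_thunks, inv_update_thunks) = fisher_est.make_ops_and_vars()

# Multiply a gradient by F^{-0.5}; any exponent listed in exps (or 1) works.
grads_and_vars = list(zip(tf.gradients(tf.reduce_sum(logits), [weights]),
                          [weights]))
precon_grads_and_vars = fisher_est.multiply_matpower(-0.5, grads_and_vars)
```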

- Added multi-tower support to multi/RNN fully connected layers
- All op creation is now done inside functions that explicitly create ops, allowing fine-grained control over their placement. One result of this is that we no longer need any colocation statements (and these have been removed)
- Multi-tower computations are now handled using the PartitionedTensor class, which appears to the FisherFactors to be a single tensor but actually contains a list of tensors (one per tower).
- To achieve the above, damping values are passed around as special functions that are packaged together with "ids" that can be used to uniquely identify the computation they perform (see the damping-function sketch after this list). Topohash might provide a better solution for this in the future.
- Variable creation in the factors is now done via special methods so that we have fine-grained control over where these variables are placed
- FisherEstimator now has special functions to create ops and variables using different placement strategies (currently: no strategy, round-robin, and as thunks); see the thunk-based sketch after this list. By default this will use the round-robin strategy and manufacture the usual convenience properties ("inv_update_ops", etc.). This default behavior preserves backwards compatibility, but in the future we should deprecate it and require the user to ask for an explicit strategy.
- LossFunctions no longer make any ops in their constructors. They only make ops when evaluated. LayerCollection maintains a list of tensors/ops with which we can colocate LossFunction computations (typically their inputs)
- LossFunctions no longer support multi-tower/mini-batches directly. Instead, LayerCollection maintains a list of these objects, one for each tower (see the multi-tower loss sketch after this list). This solution is better since the loss-function-related computations can now take place exclusively on the corresponding tower.
- All loss functions now support multiple towers/minibatches (via LayerCollection).
- tf.gradients is now passed a list of loss function values instead of their sum, which prevents extraneous gradient ops from being placed on arbitrary devices (see the tf.gradients sketch after this list). Hopefully, with this change and the loss function change above, all ops associated with gradient computations (for computing stats) will occur entirely on the device that defines that part of the graph, e.g. this will do the right thing for multiple towers
- I've also made sure that sensible colocation occurs for the extra ops needed by the curvature_propagation and exact estimation modes.
- Variables and ops made by FisherEstimator are now placed inside of name scopes (based on the name given to FisherEstimator)
- Restored old variable use count tracker implementation, thus fixing the issue with how generic registrations were handled by check_registration().
- Restored interface to FisherEstimator (which was changed in the previous CL).
- Fixed a bug in LazyKFacOptimizer: optional/named arguments weren't being passed in properly
- Lots of other minor refactors/improvements
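
A minimal, hedged sketch of the packaged damping functions mentioned above, based on the PackagedFunc/_package_func code in the fisher_blocks diff below and the make_damping_func helper used in the updated factor tests.

```python
# Hedged sketch: how damping values travel as zero-arg functions with stable
# ids (PackagedFunc / _package_func, defined in the fisher_blocks diff below).
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb

# Two packaged thunks that compute the same damping value and carry the same id.
damping_a = fb._package_func(lambda: 0.1, 0.1)
damping_b = fb._package_func(lambda: 0.1, 0.1)

print(damping_a())                             # calling it evaluates the thunk -> 0.1
print(damping_a.func_id == damping_b.func_id)  # True: same id, so a factor can
                                               # share one inverse variable between
                                               # them (as in the factor tests below)
```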
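A minimal sketch of the explicit thunk-based strategy, following the updated EstimatorTest cases below. fisher_est is assumed to be an already-constructed FisherEstimator; the variable-creation thunks must run before the corresponding update-op thunks.

```python
# Hedged sketch following the updated estimator tests: fisher_est is an
# already-constructed FisherEstimator (an assumption of this example).
import tensorflow as tf

(cov_variable_thunks, cov_update_thunks,
 inv_variable_thunks, inv_update_thunks) = fisher_est.create_ops_and_vars_thunks()

# Create all covariance and inverse variables up front (no update ops yet).
for thunk in cov_variable_thunks + inv_variable_thunks:
  thunk()

# Build the update ops wherever the caller's device/variable-scope context says,
# e.g. one covariance update per global step and all inverse updates grouped.
global_step = tf.train.get_or_create_global_step()
cov_update_op = tf.case([(tf.equal(global_step, i), thunk)
                         for i, thunk in enumerate(cov_update_thunks)])
inv_update_op = tf.group(*[thunk() for thunk in inv_update_thunks])
```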
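A minimal sketch of per-tower loss registration through LayerCollection, following the updated LayerCollectionTest cases below: the first call under a given name creates the loss function, and later calls with reuse=True register additional towers with it.

```python
# Hedged sketch based on the updated LayerCollection tests (towers_by_loss).
import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import layer_collection

lc = layer_collection.LayerCollection()
tower_logits = [tf.random_uniform([2, 3]) for _ in range(3)]

# First registration creates the loss; reuse=True adds the remaining towers.
lc.register_categorical_predictive_distribution(tower_logits[0], name="loss")
for logits in tower_logits[1:]:
  lc.register_categorical_predictive_distribution(logits, name="loss",
                                                  reuse=True)

assert len(lc.towers_by_loss) == 1      # one loss function...
assert len(lc.towers_by_loss[0]) == 3   # ...with three registered towers
```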
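A short sketch of the tf.gradients change described above: passing the list of per-tower losses is mathematically equivalent to differentiating their sum (tf.gradients adds up the contributions), but it avoids creating the sum, and its backward ops, on whatever device happens to be the default.

```python
# Hedged sketch: stats gradients from a list of per-tower losses rather than
# their sum, so each loss's backward ops stay colocated with its forward ops.
import tensorflow as tf

def stats_grads(per_tower_losses, params):
  # tf.gradients accepts a list of ys and sums their gradients, giving the same
  # result as differentiating the explicit sum without materializing that sum
  # on a single device.
  return tf.gradients(per_tower_losses, params,
                      colocate_gradients_with_ops=True)
```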

PiperOrigin-RevId: 188310846
A. Unique TensorFlower 2018-03-08 04:11:24 -08:00 committed by TensorFlower Gardener
parent e52f916b87
commit 4ac1fee7f1
13 changed files with 1634 additions and 1148 deletions


@ -36,6 +36,7 @@ py_test(
srcs = ["fisher_factors_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/kfac/python/ops:fisher_blocks",
"//tensorflow/contrib/kfac/python/ops:fisher_factors",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",


@ -90,59 +90,75 @@ class EstimatorTest(test.TestCase):
def testEstimatorInitManualRegistration(self):
with self._graph.as_default():
# We should be able to build an estimator for only the registered vars.
estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1,
estimator.FisherEstimator([self.weights], 0.1, 0.2,
self.layer_collection)
# Check that we throw an error if we try to build an estimator for vars
# that were not manually registered.
with self.assertRaises(ValueError):
estimator.FisherEstimator(lambda: 0.2, [self.weights, self.bias], 0.1,
estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2,
self.layer_collection)
# Check that we throw an error if we don't include registered variables,
# i.e. self.weights
with self.assertRaises(ValueError):
estimator.FisherEstimator(lambda: 0.2, [], 0.1, self.layer_collection)
estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection)
@test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42)
def testVariableWrongNumberOfUses(self, mock_uses):
with self.assertRaises(ValueError):
estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1,
estimator.FisherEstimator([self.weights], 0.1, 0.2,
self.layer_collection)
def testInvalidEstimationMode(self):
with self.assertRaises(ValueError):
estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1,
self.layer_collection, "not_a_real_mode")
estimator.FisherEstimator([self.weights], 0.1, 0.2,
self.layer_collection,
estimation_mode="not_a_real_mode")
def testModeListCorrect(self):
def testGradientsModeBuild(self):
with self._graph.as_default():
est = estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1,
self.layer_collection)
self.assertItemsEqual(_ALL_ESTIMATION_MODES, est._gradient_fns.keys())
estimator.FisherEstimator([self.weights], 0.1, 0.2,
self.layer_collection,
estimation_mode="gradients")
def testAllModesBuild(self):
for mode in _ALL_ESTIMATION_MODES:
with self._graph.as_default():
estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1,
self.layer_collection, mode)
def testEmpiricalModeBuild(self):
with self._graph.as_default():
estimator.FisherEstimator([self.weights], 0.1, 0.2,
self.layer_collection,
estimation_mode="empirical")
def testCurvaturePropModeBuild(self):
with self._graph.as_default():
estimator.FisherEstimator([self.weights], 0.1, 0.2,
self.layer_collection,
estimation_mode="curvature_prop")
def testExactModeBuild(self):
with self._graph.as_default():
estimator.FisherEstimator([self.weights], 0.1, 0.2,
self.layer_collection,
estimation_mode="exact")
def test_cov_update_thunks(self):
"""Ensures covariance update ops run once per global_step."""
with self._graph.as_default(), self.test_session() as sess:
fisher_estimator = estimator.FisherEstimator(
damping_fn=lambda: 0.2,
variables=[self.weights],
layer_collection=self.layer_collection,
damping=0.2,
cov_ema_decay=0.0)
# Construct an op that executes one covariance update per step.
global_step = training_util.get_or_create_global_step()
(cov_variable_thunks, cov_update_op_thunks,
_, _) = fisher_estimator.create_ops_and_vars_thunks()
for thunk in cov_variable_thunks:
thunk()
cov_matrices = [
fisher_factor.get_cov()
for fisher_factor in self.layer_collection.get_factors()
]
cov_update_op_thunks = fisher_estimator.cov_update_thunks
cov_update_op = control_flow_ops.case(
[(math_ops.equal(global_step, i), thunk)
for i, thunk in enumerate(cov_update_op_thunks)])
@ -178,19 +194,24 @@ class EstimatorTest(test.TestCase):
"""Ensures inverse update ops run once per global_step."""
with self._graph.as_default(), self.test_session() as sess:
fisher_estimator = estimator.FisherEstimator(
damping_fn=lambda: 0.2,
variables=[self.weights],
layer_collection=self.layer_collection,
damping=0.2,
cov_ema_decay=0.0)
# Construct op that updates one inverse per global step.
global_step = training_util.get_or_create_global_step()
(cov_variable_thunks, _, inv_variable_thunks,
inv_update_op_thunks) = fisher_estimator.create_ops_and_vars_thunks()
for thunk in cov_variable_thunks:
thunk()
for thunk in inv_variable_thunks:
thunk()
inv_matrices = [
matrix
for fisher_factor in self.layer_collection.get_factors()
for matrix in fisher_factor._inverses_by_damping.values()
for matrix in fisher_factor._matpower_by_exp_and_damping.values()
]
inv_update_op_thunks = fisher_estimator.inv_update_thunks
inv_update_op = control_flow_ops.case(
[(math_ops.equal(global_step, i), thunk)
for i, thunk in enumerate(inv_update_op_thunks)])


@ -94,6 +94,9 @@ class FullFBTest(test.TestCase):
block.register_additional_minibatch(32)
grads = (params[0]**2, math_ops.sqrt(params[1]))
block.instantiate_factors((grads,), 0.5)
block._factor.instantiate_cov_variables()
block.register_inverse()
block._factor.instantiate_inv_variables()
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@ -112,6 +115,9 @@ class FullFBTest(test.TestCase):
block.register_additional_minibatch(32)
grads = params**2
block.instantiate_factors((grads,), 0.5)
block._factor.instantiate_cov_variables()
block.register_inverse()
block._factor.instantiate_inv_variables()
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@ -131,6 +137,9 @@ class FullFBTest(test.TestCase):
grads = (array_ops.constant([2., 3.]), array_ops.constant(4.))
damping = 0.5
block.instantiate_factors((grads,), damping)
block._factor.instantiate_cov_variables()
block.register_inverse()
block._factor.instantiate_inv_variables()
# Make sure our inverse is something other than the identity.
sess.run(state_ops.assign(block._factor._cov, _make_psd(3)))
@ -185,6 +194,7 @@ class NaiveDiagonalFBTest(test.TestCase):
block.register_additional_minibatch(32)
grads = (params[0]**2, math_ops.sqrt(params[1]))
block.instantiate_factors((grads,), 0.5)
block._factor.instantiate_cov_variables()
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@ -203,6 +213,7 @@ class NaiveDiagonalFBTest(test.TestCase):
block.register_additional_minibatch(32)
grads = params**2
block.instantiate_factors((grads,), 0.5)
block._factor.instantiate_cov_variables()
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@ -221,6 +232,7 @@ class NaiveDiagonalFBTest(test.TestCase):
grads = (params[0]**2, math_ops.sqrt(params[1]))
damping = 0.5
block.instantiate_factors((grads,), damping)
block._factor.instantiate_cov_variables()
cov = array_ops.reshape(array_ops.constant([2., 3., 4.]), [-1, 1])
sess.run(state_ops.assign(block._factor._cov, cov))
@ -367,6 +379,7 @@ class FullyConnectedDiagonalFBTest(test.TestCase):
block.register_additional_minibatch(i, o)
block.instantiate_factors((output_grads,), damping=0.0)
block._factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
sess.run(block._factor.make_covariance_update_op(0.0))
@ -394,7 +407,7 @@ class EmbeddingKFACFBTest(test.TestCase):
# Instantiate factor's variables. Ensure it doesn't fail.
grads = outputs**2.
damping = array_ops.constant(0.)
block.instantiate_factors(([grads],), damping)
block.instantiate_factors(((grads,),), damping)
def testMultiplyInverse(self):
with ops.Graph().as_default(), self.test_session() as sess:
@ -412,7 +425,12 @@ class EmbeddingKFACFBTest(test.TestCase):
# Instantiate factor's variables. Ensure it doesn't fail.
grads = outputs**2.
damping = array_ops.constant(0.)
block.instantiate_factors(([grads],), damping)
block.instantiate_factors(((grads,),), damping)
block._input_factor.instantiate_cov_variables()
block._output_factor.instantiate_cov_variables()
block.register_inverse()
block._input_factor.instantiate_inv_variables()
block._output_factor.instantiate_inv_variables()
# Create a sparse update.
indices = array_ops.constant([1, 3, 4])
@ -456,7 +474,7 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
block.instantiate_factors(([grads],), 0.5)
block.instantiate_factors(((grads,),), 0.5)
def testInstantiateFactorsNoBias(self):
with ops.Graph().as_default():
@ -467,7 +485,7 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
block.instantiate_factors(([grads],), 0.5)
block.instantiate_factors(((grads,),), 0.5)
def testMultiplyInverseTuple(self):
with ops.Graph().as_default(), self.test_session() as sess:
@ -477,7 +495,13 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
block.instantiate_factors(([grads],), 0.5)
block.instantiate_factors(((grads,),), 0.5)
block._input_factor.instantiate_cov_variables()
block._output_factor.instantiate_cov_variables()
block.register_inverse()
block._input_factor.instantiate_inv_variables()
block._output_factor.instantiate_inv_variables()
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@ -503,7 +527,12 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
block.instantiate_factors(([grads],), 0.5)
block.instantiate_factors(((grads,),), 0.5)
block._input_factor.instantiate_cov_variables()
block._output_factor.instantiate_cov_variables()
block.register_inverse()
block._input_factor.instantiate_inv_variables()
block._output_factor.instantiate_inv_variables()
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@ -527,10 +556,17 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
damping = 0. # This test is only valid without damping.
block.instantiate_factors(([grads],), damping)
block.instantiate_factors(((grads,),), damping)
block._input_factor.instantiate_cov_variables()
block._output_factor.instantiate_cov_variables()
sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3)))
sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
block.register_inverse()
block._input_factor.instantiate_inv_variables()
block._output_factor.instantiate_inv_variables()
sess.run(block._input_factor.make_inverse_update_ops())
sess.run(block._output_factor.make_inverse_update_ops())
@ -718,6 +754,7 @@ class ConvDiagonalFBTest(test.TestCase):
block.register_additional_minibatch(i, o)
block.instantiate_factors((output_grads,), damping=0.0)
block._factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
sess.run(block._factor.make_covariance_update_op(0.0))
@ -759,7 +796,12 @@ class ConvKFCBasicFBTest(test.TestCase):
'SAME')
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
block.instantiate_factors(([grads],), 0.5)
block.instantiate_factors(((grads,),), 0.5)
block._input_factor.instantiate_cov_variables()
block._output_factor.instantiate_cov_variables()
block.register_inverse()
block._input_factor.instantiate_inv_variables()
block._output_factor.instantiate_inv_variables()
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@ -786,7 +828,12 @@ class ConvKFCBasicFBTest(test.TestCase):
block.register_additional_minibatch(inputs, outputs)
self.assertFalse(block._has_bias)
grads = outputs**2
block.instantiate_factors(([grads],), 0.5)
block.instantiate_factors(((grads,),), 0.5)
block._input_factor.instantiate_cov_variables()
block._output_factor.instantiate_cov_variables()
block.register_inverse()
block._input_factor.instantiate_inv_variables()
block._output_factor.instantiate_inv_variables()
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@ -809,7 +856,12 @@ class ConvKFCBasicFBTest(test.TestCase):
block.register_additional_minibatch(inputs, outputs)
self.assertTrue(block._has_bias)
grads = outputs**2
block.instantiate_factors(([grads],), 0.5)
block.instantiate_factors(((grads,),), 0.5)
block._input_factor.instantiate_cov_variables()
block._output_factor.instantiate_cov_variables()
block.register_inverse()
block._input_factor.instantiate_inv_variables()
block._output_factor.instantiate_inv_variables()
# Make sure our inverse is something other than the identity.
sess.run(tf_variables.global_variables_initializer())
@ -832,7 +884,12 @@ class ConvKFCBasicFBTest(test.TestCase):
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
damping = 0. # This test is only valid without damping.
block.instantiate_factors(([grads],), damping)
block.instantiate_factors(((grads,),), damping)
block._input_factor.instantiate_cov_variables()
block._output_factor.instantiate_cov_variables()
block.register_inverse()
block._input_factor.instantiate_inv_variables()
block._output_factor.instantiate_inv_variables()
sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8)))
sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
@ -857,9 +914,9 @@ class FullyConnectedSeriesFBTest(test.TestCase):
random_seed.set_random_seed(200)
inputs = array_ops.constant([1., 2.])
outputs = array_ops.constant([3., 4.])
block = fb.FullyConnectedSeriesFB(
lc.LayerCollection(), inputs=[inputs], outputs=[outputs])
self.assertAllEqual([outputs], block.tensors_to_compute_grads())
block = fb.FullyConnectedSeriesFB(lc.LayerCollection())
block.register_additional_minibatch([inputs], [outputs])
self.assertAllEqual([[outputs]], block.tensors_to_compute_grads())
def testInstantiateFactorsHasBias(self):
with ops.Graph().as_default():
@ -868,11 +925,10 @@ class FullyConnectedSeriesFBTest(test.TestCase):
outputs = array_ops.constant([[3., 4.], [5., 6.]])
block = fb.FullyConnectedSeriesFB(
lc.LayerCollection(),
inputs=[inputs],
outputs=[outputs],
has_bias=True)
block.register_additional_minibatch([inputs], [outputs])
grads = outputs**2
block.instantiate_factors(((grads,),), 0.5)
block.instantiate_factors((((grads,),),), 0.5)
def testInstantiateFactorsNoBias(self):
with ops.Graph().as_default():
@ -881,11 +937,10 @@ class FullyConnectedSeriesFBTest(test.TestCase):
outputs = array_ops.constant([[3., 4.], [5., 6.]])
block = fb.FullyConnectedSeriesFB(
lc.LayerCollection(),
inputs=[inputs],
outputs=[outputs],
has_bias=False)
block.register_additional_minibatch([inputs], [outputs])
grads = outputs**2
block.instantiate_factors(((grads,),), 0.5)
block.instantiate_factors((((grads,),),), 0.5)
def as_tensors(tensor_or_tuple):


@ -21,8 +21,8 @@ from __future__ import print_function
import numpy as np
import numpy.random as npr
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.framework import random_seed
@ -33,32 +33,8 @@ from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import test
class MaybeColocateTest(test.TestCase):
def setUp(self):
self._colocate_cov_ops_with_inputs = ff.COLOCATE_COV_OPS_WITH_INPUTS
def tearDown(self):
ff.set_global_constants(
colocate_cov_ops_with_inputs=self._colocate_cov_ops_with_inputs)
def testFalse(self):
ff.set_global_constants(colocate_cov_ops_with_inputs=False)
with tf_ops.Graph().as_default():
a = constant_op.constant([2.0], name='a')
with ff.maybe_colocate_with(a):
b = constant_op.constant(3.0, name='b')
self.assertEqual([b'loc:@a'], a.op.colocation_groups())
self.assertEqual([b'loc:@b'], b.op.colocation_groups())
def testTrue(self):
ff.set_global_constants(colocate_cov_ops_with_inputs=True)
with tf_ops.Graph().as_default():
a = constant_op.constant([2.0], name='a')
with ff.maybe_colocate_with(a):
b = constant_op.constant(3.0, name='b')
self.assertEqual([b'loc:@a'], a.op.colocation_groups())
self.assertEqual([b'loc:@a'], b.op.colocation_groups())
def make_damping_func(damping):
return fb._package_func(lambda: damping, damping)
class FisherFactorTestingDummy(ff.FisherFactor):
@ -98,10 +74,13 @@ class FisherFactorTestingDummy(ff.FisherFactor):
def right_multiply(self, x, damping):
return NotImplementedError
def left_multiply_inverse(self, x, damping):
def left_multiply_matpower(self, x, exp, damping):
return NotImplementedError
def right_multiply_inverse(self, x, damping):
def right_multiply_matpower(self, x, exp, damping):
return NotImplementedError
def instantiate_inv_variables(self):
return NotImplementedError
@ -246,21 +225,24 @@ class InverseProvidingFactorTest(test.TestCase):
factor = InverseProvidingFactorTestingDummy(shape)
factor_var_scope = 'dummy/a_b_c'
dampings = 0.1, 1e-1, 0.00001, 1e-5
damping_funcs = [make_damping_func(0.1),
make_damping_func(0.1),
make_damping_func(1e-5),
make_damping_func(1e-5)]
for damping_func in damping_funcs:
factor.register_inverse(damping_func)
for damping in dampings:
factor.register_damped_inverse(damping)
factor.instantiate_inv_variables()
self.assertEqual(set(dampings), set(factor._inverses_by_damping.keys()))
inv = factor._inverses_by_damping[dampings[0]]
self.assertEqual(inv, factor._inverses_by_damping[dampings[1]])
self.assertNotEqual(inv, factor._inverses_by_damping[dampings[2]])
self.assertEqual(factor._inverses_by_damping[dampings[2]],
factor._inverses_by_damping[dampings[3]])
inv = factor.get_inverse(damping_funcs[0])
self.assertEqual(inv, factor.get_inverse(damping_funcs[1]))
self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]))
self.assertEqual(factor.get_inverse(damping_funcs[2]),
factor.get_inverse(damping_funcs[3]))
factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
factor_var_scope)
self.assertListEqual([inv, factor._inverses_by_damping[dampings[2]]],
factor_vars)
self.assertEqual(set([inv, factor.get_inverse(damping_funcs[2])]),
set(factor_vars))
self.assertEqual(shape, inv.get_shape())
def testRegisterMatpower(self):
@ -270,17 +252,22 @@ class InverseProvidingFactorTest(test.TestCase):
factor = InverseProvidingFactorTestingDummy(shape)
factor_var_scope = 'dummy/a_b_c'
factor.register_matpower(1, 0.5)
factor.register_matpower(2, 0.5)
# TODO(b/74201126): Change to using the same func for both once
# Topohash is in place.
damping_func_1 = make_damping_func(0.5)
damping_func_2 = make_damping_func(0.5)
factor.register_matpower(-0.5, damping_func_1)
factor.register_matpower(2, damping_func_2)
factor.instantiate_inv_variables()
self.assertEqual(
set([(1, 0.5), (2, 0.5)]),
set(factor._matpower_by_exp_and_damping.keys()))
factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
factor_var_scope)
matpower1 = factor.get_matpower(1, 0.5)
matpower2 = factor.get_matpower(2, 0.5)
self.assertListEqual([matpower1, matpower2], factor_vars)
matpower1 = factor.get_matpower(-0.5, damping_func_1)
matpower2 = factor.get_matpower(2, damping_func_2)
self.assertEqual(set([matpower1, matpower2]), set(factor_vars))
self.assertEqual(shape, matpower1.get_shape())
self.assertEqual(shape, matpower2.get_shape())
@ -299,17 +286,24 @@ class InverseProvidingFactorTest(test.TestCase):
factor = InverseProvidingFactorTestingDummy(cov.shape)
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
damping_funcs = []
for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1):
factor.register_damped_inverse(1. / i)
damping_funcs.append(make_damping_func(1./i))
for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
factor.register_inverse(damping_funcs[i])
factor.instantiate_inv_variables()
ops = factor.make_inverse_update_ops()
self.assertEqual(1, len(ops))
sess.run(tf_variables.global_variables_initializer())
new_invs = []
sess.run(ops)
for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1):
for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
# The inverse op will assign the damped inverse of cov to the inv var.
new_invs.append(sess.run(factor._inverses_by_damping[1. / i]))
new_invs.append(sess.run(factor.get_inverse(damping_funcs[i])))
# We want to see that the new invs are all different from each other.
for i in range(len(new_invs)):
for j in range(i + 1, len(new_invs)):
@ -324,14 +318,16 @@ class InverseProvidingFactorTest(test.TestCase):
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power
damping = 0.5
damping_func = make_damping_func(damping)
factor.register_matpower(exp, damping)
factor.register_matpower(exp, damping_func)
factor.instantiate_inv_variables()
ops = factor.make_inverse_update_ops()
self.assertEqual(1, len(ops))
sess.run(tf_variables.global_variables_initializer())
sess.run(ops[0])
matpower = sess.run(factor._matpower_by_exp_and_damping[(exp, damping)])
matpower = sess.run(factor.get_matpower(exp, damping_func))
matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp)
self.assertAllClose(matpower, matpower_np)
@ -342,18 +338,21 @@ class InverseProvidingFactorTest(test.TestCase):
factor = InverseProvidingFactorTestingDummy(cov.shape)
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
factor.register_damped_inverse(0)
damping_func = make_damping_func(0)
factor.register_inverse(damping_func)
factor.instantiate_inv_variables()
ops = factor.make_inverse_update_ops()
self.assertEqual(1, len(ops))
sess.run(tf_variables.global_variables_initializer())
# The inverse op will assign the damped inverse of cov to the inv var.
old_inv = sess.run(factor._inverses_by_damping[0])
old_inv = sess.run(factor.get_inverse(damping_func))
self.assertAllClose(
sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv)
sess.run(ops)
new_inv = sess.run(factor._inverses_by_damping[0])
new_inv = sess.run(factor.get_inverse(damping_func))
self.assertAllClose(new_inv, np.linalg.inv(cov))
@ -364,6 +363,7 @@ class FullFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), name='a/b/c')
factor = ff.FullFactor((tensor,), 32)
factor.instantiate_cov_variables()
self.assertEqual([6, 6], factor.get_cov().get_shape().as_list())
def testFullFactorInitFloat64(self):
@ -372,6 +372,7 @@ class FullFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
factor = ff.FullFactor((tensor,), 32)
factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
self.assertEqual([6, 6], cov.get_shape().as_list())
@ -381,6 +382,7 @@ class FullFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.constant([1., 2.], name='a/b/c')
factor = ff.FullFactor((tensor,), 2)
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
@ -394,6 +396,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), name='a/b/c')
factor = ff.NaiveDiagonalFactor((tensor,), 32)
factor.instantiate_cov_variables()
self.assertEqual([6, 1], factor.get_cov_var().get_shape().as_list())
def testNaiveDiagonalFactorInitFloat64(self):
@ -402,6 +405,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
factor = ff.NaiveDiagonalFactor((tensor,), 32)
factor.instantiate_cov_variables()
cov = factor.get_cov_var()
self.assertEqual(cov.dtype, dtype)
self.assertEqual([6, 1], cov.get_shape().as_list())
@ -411,6 +415,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.constant([1., 2.], name='a/b/c')
factor = ff.NaiveDiagonalFactor((tensor,), 2)
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
@ -423,7 +428,8 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase):
with tf_ops.Graph().as_default():
input_ids = array_ops.constant([[0], [1], [4]])
vocab_size = 5
factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
factor = ff.EmbeddingInputKroneckerFactor(input_ids, vocab_size)
factor.instantiate_cov_variables()
cov = factor.get_cov_var()
self.assertEqual(cov.shape.as_list(), [vocab_size])
@ -431,7 +437,8 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase):
with tf_ops.Graph().as_default():
input_ids = array_ops.constant([[0], [1], [4]])
vocab_size = 5
factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
factor = ff.EmbeddingInputKroneckerFactor(input_ids, vocab_size)
factor.instantiate_cov_variables()
cov_update_op = factor.make_covariance_update_op(0.0)
with self.test_session() as sess:
@ -450,6 +457,7 @@ class FullyConnectedKroneckerFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=has_bias)
factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
self.assertEqual(final_shape, cov.get_shape().as_list())
@ -467,6 +475,7 @@ class FullyConnectedKroneckerFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=True)
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
@ -477,6 +486,7 @@ class FullyConnectedKroneckerFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
factor = ff.FullyConnectedKroneckerFactor((tensor,))
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
@ -491,6 +501,7 @@ class ConvInputKroneckerFactorTest(test.TestCase):
tensor = array_ops.ones((2, 3), name='a/b/c')
factor = ff.ConvInputKroneckerFactor(
tensor, (1, 2, 3, 4), 3, 2, has_bias=False)
factor.instantiate_cov_variables()
self.assertEqual([1 * 2 * 3, 1 * 2 * 3],
factor.get_cov().get_shape().as_list())
@ -500,6 +511,7 @@ class ConvInputKroneckerFactorTest(test.TestCase):
tensor = array_ops.ones((2, 3), name='a/b/c')
factor = ff.ConvInputKroneckerFactor(
tensor, (1, 2, 3, 4), 3, 2, has_bias=True)
factor.instantiate_cov_variables()
self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
factor.get_cov().get_shape().as_list())
@ -510,6 +522,7 @@ class ConvInputKroneckerFactorTest(test.TestCase):
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
factor = ff.ConvInputKroneckerFactor(
tensor, (1, 2, 3, 4), 3, 2, has_bias=True)
factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
@ -522,6 +535,7 @@ class ConvInputKroneckerFactorTest(test.TestCase):
np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32)
factor = ff.ConvInputKroneckerFactor(
tensor, (1, 2, 1, 1), [1, 1, 1, 1], 'SAME', has_bias=True)
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
@ -533,8 +547,9 @@ class ConvInputKroneckerFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.constant(
np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32)
factor = ff.ConvInputKroneckerFactor(tensor, (1, 2, 1, 1), [1, 1, 1, 1],
'SAME')
factor = ff.ConvInputKroneckerFactor(tensor, (1, 2, 1, 1),
[1, 1, 1, 1], 'SAME')
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
@ -548,6 +563,7 @@ class ConvOutputKroneckerFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3, 4, 5), name='a/b/c')
factor = ff.ConvOutputKroneckerFactor((tensor,))
factor.instantiate_cov_variables()
self.assertEqual([5, 5], factor.get_cov().get_shape().as_list())
def testConvOutputKroneckerFactorInitFloat64(self):
@ -556,6 +572,7 @@ class ConvOutputKroneckerFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c')
factor = ff.ConvOutputKroneckerFactor((tensor,))
factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
self.assertEqual([5, 5], cov.get_shape().as_list())
@ -565,13 +582,14 @@ class ConvOutputKroneckerFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), name='a/b/c')
with self.assertRaises(IndexError):
ff.ConvOutputKroneckerFactor(tensor)
ff.ConvOutputKroneckerFactor((tensor,))
def testMakeCovarianceUpdateOp(self):
with tf_ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
tensor = np.arange(1, 17).reshape(2, 2, 2, 2).astype(np.float32)
factor = ff.ConvOutputKroneckerFactor((array_ops.constant(tensor),))
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
@ -586,6 +604,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
tensor = array_ops.ones((2, 3), name='a/b/c')
tensor_list = [tensor]
factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False)
factor.instantiate_cov_variables()
self.assertEqual([3, 3], factor.get_cov().get_shape().as_list())
def testFullyConnectedMultiKFInitFloat64(self):
@ -595,6 +614,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
tensor_list = [tensor]
factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False)
factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
self.assertEqual([3, 3], cov.get_shape().as_list())
@ -605,6 +625,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
tensor_list = [tensor]
factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=True)
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
@ -616,6 +637,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
tensor_list = [tensor]
factor = ff.FullyConnectedMultiKF((tensor_list,))
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))


@ -237,16 +237,16 @@ class LayerCollectionTest(test.TestCase):
# Create a new loss function by name.
lc.register_categorical_predictive_distribution(logits, name='loss1')
self.assertEqual(1, len(lc.losses))
self.assertEqual(1, len(lc.towers_by_loss))
# Add logits to same loss function.
lc.register_categorical_predictive_distribution(
logits, name='loss1', reuse=True)
self.assertEqual(1, len(lc.losses))
self.assertEqual(1, len(lc.towers_by_loss))
# Add another new loss function.
lc.register_categorical_predictive_distribution(logits, name='loss2')
self.assertEqual(2, len(lc.losses))
self.assertEqual(2, len(lc.towers_by_loss))
def testLossFunctionWithoutName(self):
"""Ensure loss functions get unique names if 'name' not specified."""
@ -298,13 +298,9 @@ class LayerCollectionTest(test.TestCase):
name='loss1',
reuse=layer_collection.VARIABLE_SCOPE)
self.assertEqual(len(lc.losses), 1)
loss = lc.losses[0]
self.assertEqual(len(lc.towers_by_loss), 1)
# Three successful registrations.
self.assertEqual(loss.params.shape.as_list(),
[3 * batch_size, output_size])
self.assertEqual(loss.targets.shape.as_list(), [3 * batch_size])
self.assertEqual(len(lc.towers_by_loss[0]), 3)
def testRegisterCategoricalPredictiveDistributionBatchSize1(self):
with ops.Graph().as_default():
@ -479,17 +475,6 @@ class LayerCollectionTest(test.TestCase):
variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertTrue(all([var.name.startswith(scope) for var in variables]))
def testGetUseCountMap(self):
"""Ensure get_use_count_map() sums 'num_registered_minibatches'."""
lc = layer_collection.LayerCollection()
lc.fisher_blocks = {
'a': MockFisherBlock(),
('a', 'c'): MockFisherBlock(),
('b', 'c'): MockFisherBlock()
}
use_count_map = lc.get_use_count_map()
self.assertDictEqual({'a': 4, 'b': 2, 'c': 4}, use_count_map)
def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self):
x = variable_scope.get_variable('x', shape=())
y = variable_scope.get_variable('y', shape=())


@ -24,7 +24,6 @@ from tensorflow.contrib.kfac.python.ops import loss_functions
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
@ -97,22 +96,6 @@ class CategoricalLogitsNegativeLogProbLossTest(test.TestCase):
# difficult to say if the output is correct or not...
neg_log_prob = sess.run(neg_log_prob)
def testMultiMinibatchRegistration(self):
"""Ensure this loss function supports registering multiple minibatches."""
with ops.Graph().as_default():
tower_logits = []
loss = None
num_towers = 5
for _ in range(num_towers):
logits = random_ops.random_uniform(shape=[2, 3])
tower_logits.append(logits)
if loss is None:
loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
else:
loss.register_additional_minibatch(logits)
self.assertListEqual(loss.input_minibatches, tower_logits)
self.assertEqual(loss.num_registered_minibatches, num_towers)
def testMultiplyFisherSingleVector(self):
with ops.Graph().as_default(), self.test_session() as sess:
logits = np.array([1., 2., 3.])
@ -203,23 +186,5 @@ class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase):
# difficult to say if the output is correct or not...
neg_log_prob = sess.run(neg_log_prob)
def testMultiMinibatchRegistration(self):
"""Ensure this loss function supports registering multiple minibatches."""
with ops.Graph().as_default():
tower_logits = []
loss = None
num_towers = 5
for _ in range(num_towers):
logits = random_ops.random_uniform(shape=[2, 3])
tower_logits.append(logits)
if loss is None:
loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
logits)
else:
loss.register_additional_minibatch(logits)
self.assertListEqual(loss.input_minibatches, tower_logits)
self.assertEqual(loss.num_registered_minibatches, num_towers)
if __name__ == "__main__":
test.main()


@ -27,6 +27,7 @@ from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
@ -65,6 +66,13 @@ class _DeviceContextGenerator(object):
yield
def _make_thunk_on_device(func, device):
def thunk():
with tf_ops.device(device):
return func()
return thunk
class FisherEstimator(object):
"""Fisher estimator class supporting various approximations of the Fisher.
@ -83,26 +91,35 @@ class FisherEstimator(object):
"""
def __init__(self,
damping_fn,
variables,
cov_ema_decay,
damping,
layer_collection,
exps=(-1,),
estimation_mode="gradients",
colocate_gradients_with_ops=True,
cov_devices=None,
inv_devices=None):
name="FisherEstimator"):
"""Create a FisherEstimator object.
Args:
damping_fn: Function, accepts no arguments and returns damping value.
variables: A list of the variables for which to estimate the Fisher. This
must match the variables registered in layer_collection (if it is not
None).
cov_ema_decay: The decay factor used when calculating the covariance
estimate moving averages.
damping: float. The damping factor used to stabilize training due to
errors in the local approximation with the Fisher information matrix,
and to regularize the update direction by making it closer to the
gradient. (Higher damping means the update looks more like a standard
gradient update - see Tikhonov regularization.)
layer_collection: The layer collection object, which holds the fisher
blocks, kronecker factors, and losses associated with the
graph.
exps: List of floats or ints. These represent the different matrix
powers of the approximate Fisher that the FisherEstimator will be able
to multiply vectors by. If the user asks for a matrix power other
one of these (or 1, which is always supported), there will be a
failure. (Default: (-1,))
estimation_mode: The type of estimator to use for the Fishers. Can be
'gradients', 'empirical', 'curvature_prop', or 'exact'.
(Default: 'gradients'). 'gradients' is the basic estimation approach
@ -121,19 +138,15 @@ class FisherEstimator(object):
equal to the output dimension, roughly speaking.
colocate_gradients_with_ops: Whether we should request gradients be
colocated with their respective ops. (Default: True)
cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
computations will be placed on these devices in a round-robin fashion.
Can be None, which means that no devices are specified.
inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
computations will be placed on these devices in a round-robin fashion.
Can be None, which means that no devices are specified.
name: A string. A name given to this estimator, which is added to the
variable scope when constructing variables and ops.
(Default: "FisherEstimator")
Raises:
ValueError: If no losses have been registered with layer_collection.
"""
self._damping_fn = damping_fn
self._cov_ema_decay = cov_ema_decay
self._variables = variables
self._cov_ema_decay = cov_ema_decay
self._damping = damping
self._estimation_mode = estimation_mode
self._layers = layer_collection
self._layers.create_subgraph()
@ -146,30 +159,13 @@ class FisherEstimator(object):
}
self._colocate_gradients_with_ops = colocate_gradients_with_ops
# TODO(b/70674513): Factor device placement outside of this class.
self._cov_device_context_generator = _DeviceContextGenerator(cov_devices)
if inv_devices == cov_devices:
self._inv_device_context_generator = self._cov_device_context_generator
else:
self._inv_device_context_generator = _DeviceContextGenerator(inv_devices)
self._made_vars = False
self._exps = exps
self._name = name
self._instantiate_factors()
self.cov_update_thunks = [
self._create_cov_update_thunk(factor)
for factor in self._layers.get_factors()
]
self.cov_update_ops = [thunk() for thunk in self.cov_update_thunks]
self.cov_update_op = control_flow_ops.group(
self.cov_update_ops, name="cov_update_op")
self.inv_update_thunks = [
self._create_inv_update_thunk(factor)
for factor in self._layers.get_factors()
]
self.inv_update_ops = [thunk() for thunk in self.inv_update_thunks]
self.inv_update_op = control_flow_ops.group(
self.inv_update_ops, name="inv_update_op")
self._register_matrix_functions()
@property
def variables(self):
@ -177,7 +173,21 @@ class FisherEstimator(object):
@property
def damping(self):
return self._damping_fn()
return self._damping
@property
def blocks(self):
"""All registered FisherBlocks."""
return self._layers.get_blocks()
@property
def factors(self):
"""All registered FisherFactors."""
return self._layers.get_factors()
@property
def name(self):
return self._name
def _apply_transformation(self, vecs_and_vars, transform):
"""Applies an block-wise transformation to the corresponding vectors.
@ -212,9 +222,7 @@ class FisherEstimator(object):
A list of (transformed vector, var) pairs in the same order as
vecs_and_vars.
"""
return self._apply_transformation(vecs_and_vars,
lambda fb, vec: fb.multiply_inverse(vec))
return self.multiply_matpower(-1, vecs_and_vars)
def multiply(self, vecs_and_vars):
"""Multiplies the vectors by the corresponding (damped) blocks.
@ -226,9 +234,22 @@ class FisherEstimator(object):
A list of (transformed vector, var) pairs in the same order as
vecs_and_vars.
"""
return self.multiply_matpower(1, vecs_and_vars)
return self._apply_transformation(vecs_and_vars,
lambda fb, vec: fb.multiply(vec))
def multiply_matpower(self, exp, vecs_and_vars):
"""Multiplies the vecs by the corresponding matrix powers of the blocks.
Args:
exp: A float representing the power to raise the blocks by before
multiplying it by the vector.
vecs_and_vars: List of (vector, variable) pairs.
Returns:
A list of (transformed vector, var) pairs in the same order as
vecs_and_vars.
"""
fcn = lambda fb, vec: fb.multiply_matpower(vec, exp)
return self._apply_transformation(vecs_and_vars, fcn)
def _instantiate_factors(self):
"""Instantiates FisherFactors' variables.
@ -236,9 +257,9 @@ class FisherEstimator(object):
Raises:
ValueError: If estimation_mode was improperly specified at construction.
"""
fisher_blocks_list = self._layers.get_blocks()
blocks = self.blocks
tensors_to_compute_grads = [
fb.tensors_to_compute_grads() for fb in fisher_blocks_list
block.tensors_to_compute_grads() for block in blocks
]
try:
@ -248,45 +269,275 @@ class FisherEstimator(object):
raise ValueError("Unrecognized value {} for estimation_mode.".format(
self._estimation_mode))
# TODO(b/68033310): This loop round-robins the "concat" operations which
# gather the inputs for the cov_updates. In future, we might do these
# computations locally then communicate the results, which would require a
# modification to this code.
for grads_list, fb in zip(grads_lists, fisher_blocks_list):
with self._cov_device_context_generator():
fb.instantiate_factors(grads_list, self.damping)
for grads_list, block in zip(grads_lists, blocks):
block.instantiate_factors(grads_list, self.damping)
def _create_cov_update_thunk(self, factor):
def _check_vars_unmade_and_set_made_flag(self):
if self._made_vars:
raise Exception("Already made variables.")
self._made_vars = True
def made_vars(self):
return self._made_vars
def _register_matrix_functions(self):
for exp in self._exps:
for block in self.blocks:
block.register_matpower(exp)
def make_ops_and_vars(self, scope=None):
"""Make ops and vars with no specific device placement.
See make_ops_and_vars_round_robin for further details.
Args:
scope: A string or None. If None it will be set to the name of this
estimator (given by the name property). All variables will be created,
and all ops will execute, inside of a variable scope of the given
name. (Default: None)
Returns:
cov_update_ops: List of ops that compute the cov updates. Corresponds
one-to-one with the list of factors given by the "factors" property.
cov_update_op: cov_update_ops grouped into a single op.
inv_update_ops: List of ops that compute the inv updates. Corresponds
one-to-one with the list of factors given by the "factors" property.
inv_update_op: inv_update_ops grouped into a single op.
cov_update_thunks: Thunks that make the ops in cov_update_ops.
inv_update_thunks: Thunks that make the ops in inv_update_ops.
"""
return self.make_ops_and_vars_round_robin(scope=scope)
# TODO(b/70674513): Factor device placement outside of this class.
def make_ops_and_vars_round_robin(self, scope=None, cov_devices=None,
inv_devices=None):
"""Make ops and vars with a round-robin device placement strategy.
For each factor, all of that factor's cov variables and their associated
update ops will be placed on a particular device. A new device is chosen
for each factor by cycling through list of devices in the cov_devices
argument. If cov_devices is None then no explicit device placement occurs.
An analogous strategy is followed for inverse update ops, with the list of
devices being given by the inv_devices argument.
Inverse variables on the other hand are not placed on any specific device
(they will just use the current device placement context, whatever
that happens to be). The idea is that the inverse variables belong where
they will be accessed most often, which is the device that actually applies
the preconditioner to the gradient. The user will be responsible for setting
the device context for this.
Args:
scope: A string or None. If None it will be set to the name of this
estimator (given by the name property). All variables will be created,
and all ops will execute, inside of a variable scope of the given
name. (Default: None)
cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
computations will be placed on these devices in a round-robin fashion.
Can be None, which means that no devices are specified.
inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
computations will be placed on these devices in a round-robin fashion.
Can be None, which means that no devices are specified.
Returns:
cov_update_ops: List of ops that compute the cov updates. Corresponds
one-to-one with the list of factors given by the "factors" property.
cov_update_op: cov_update_ops grouped into a single op.
inv_update_ops: List of ops that compute the inv updates. Corresponds
one-to-one with the list of factors given by the "factors" property.
inv_update_op: inv_update_ops grouped into a single op.
cov_update_thunks: Thunks that make the ops in cov_update_ops.
inv_update_thunks: Thunks that make the ops in inv_update_ops.
"""
(cov_update_thunks,
inv_update_thunks) = self.make_vars_and_create_op_thunks_round_robin(
scope=scope,
cov_devices=cov_devices,
inv_devices=inv_devices)
cov_update_ops = [thunk() for thunk in cov_update_thunks]
inv_update_ops = [thunk() for thunk in inv_update_thunks]
scope = self.name if scope is None else scope
with variable_scope.variable_scope(scope):
cov_update_op = control_flow_ops.group(cov_update_ops,
name="cov_update_op")
inv_update_op = control_flow_ops.group(inv_update_ops,
name="inv_update_op")
return (cov_update_ops, cov_update_op, inv_update_ops, inv_update_op,
cov_update_thunks, inv_update_thunks)
def make_vars_and_create_op_thunks_round_robin(self,
scope=None,
cov_devices=None,
inv_devices=None):
"""Make vars and create op thunks w/ a round-robin device placement strat.
For each factor, all of that factor's cov variables and their associated
update ops will be placed on a particular device. A new device is chosen
for each factor by cycling through list of devices in the cov_devices
argument. If cov_devices is None then no explicit device placement occurs.
An analogous strategy is followed for inverse update ops, with the list of
devices being given by the inv_devices argument.
Inverse variables on the other hand are not placed on any specific device
(they will just use the current device placement context, whatever
that happens to be). The idea is that the inverse variables belong where
they will be accessed most often, which is the device that actually applies
the preconditioner to the gradient. The user will be responsible for setting
the device context for this.
Args:
scope: A string or None. If None it will be set to the name of this
estimator (given by the name property). All variables will be created,
and all thunks will execute, inside of a variable scope of the given
name. (Default: None)
cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
computations will be placed on these devices in a round-robin fashion.
Can be None, which means that no devices are specified.
inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
computations will be placed on these devices in a round-robin fashion.
Can be None, which means that no devices are specified.
Returns:
cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
the list of factors given by the "factors" property.
inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
the list of factors given by the "factors" property.
"""
(cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw,
inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope)
if cov_devices:
cov_update_thunks = []
for cov_variable_thunk, cov_update_thunk, device in zip(
cov_variable_thunks_raw, cov_update_thunks_raw,
itertools.cycle(cov_devices)):
with tf_ops.device(device):
cov_variable_thunk()
cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk,
device))
else:
for cov_variable_thunk in cov_variable_thunks_raw:
cov_variable_thunk()
cov_update_thunks = cov_update_thunks_raw
for inv_variable_thunk in inv_variable_thunks_raw:
inv_variable_thunk()
if inv_devices:
inv_update_thunks = []
for inv_update_thunk, device in zip(inv_update_thunks_raw,
itertools.cycle(inv_devices)):
inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk,
device))
else:
inv_update_thunks = inv_update_thunks_raw
return cov_update_thunks, inv_update_thunks
def create_ops_and_vars_thunks(self, scope=None):
"""Create thunks that make the ops and vars on demand.
This function returns 4 lists of thunks: cov_variable_thunks,
cov_update_thunks, inv_variable_thunks, and inv_update_thunks.
The length of each list is the number of factors and the i-th element of
each list corresponds to the i-th factor (given by the "factors" property).
Note that the execution of these thunks must happen in a certain
partial order. The i-th element of cov_variable_thunks must execute
before the i-th element of cov_update_thunks (and also the i-th element
of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks
must execute before the i-th element of inv_update_thunks.
TL;DR (oversimplified): Execute the thunks according to the order that
they are returned.
Args:
scope: A string or None. If None it will be set to the name of this
estimator (given by the name property). All thunks will execute inside
of a variable scope of the given name. (Default: None)
Returns:
cov_variable_thunks: A list of thunks that make the cov variables.
cov_update_thunks: A list of thunks that make the cov update ops.
inv_variable_thunks: A list of thunks that make the inv variables.
inv_update_thunks: A list of thunks that make the inv update ops.
"""
self._check_vars_unmade_and_set_made_flag()
scope = self.name if scope is None else scope
cov_variable_thunks = [
self._create_cov_variable_thunk(factor, scope)
for factor in self.factors
]
cov_update_thunks = [
self._create_cov_update_thunk(factor, scope) for factor in self.factors
]
inv_variable_thunks = [
self._create_inv_variable_thunk(factor, scope)
for factor in self.factors
]
inv_update_thunks = [
self._create_inv_update_thunk(factor, scope) for factor in self.factors
]
return (cov_variable_thunks, cov_update_thunks,
inv_variable_thunks, inv_update_thunks)
def _create_cov_variable_thunk(self, factor, scope):
"""Constructs a covariance variable thunk for a single FisherFactor."""
def thunk():
with variable_scope.variable_scope(scope):
return factor.instantiate_cov_variables()
return thunk
def _create_cov_update_thunk(self, factor, scope):
"""Constructs a covariance update thunk for a single FisherFactor."""
def thunk():
with tf_ops.name_scope(
"create_cov_update_thunk", values=[self._cov_ema_decay]):
with variable_scope.variable_scope(scope):
return factor.make_covariance_update_op(self._cov_ema_decay)
return thunk
def _create_inv_update_thunk(self, factor):
def _create_inv_variable_thunk(self, factor, scope):
"""Constructs a inverse variable thunk for a single FisherFactor."""
def thunk():
with variable_scope.variable_scope(scope):
return factor.instantiate_inv_variables()
return thunk
def _create_inv_update_thunk(self, factor, scope):
"""Constructs an inverse update thunk for a single FisherFactor."""
def thunk():
with tf_ops.name_scope("create_inv_update_thunk"):
with self._inv_device_context_generator():
return control_flow_ops.group(factor.make_inverse_update_ops())
with variable_scope.variable_scope(scope):
return control_flow_ops.group(factor.make_inverse_update_ops())
return thunk
def _get_grads_lists_gradients(self, tensors):
# Passing in a list of loss values is better than passing in the sum as
# the latter creates unnecessary ops on the default device
grads_flat = gradients_impl.gradients(
self._layers.total_sampled_loss(),
self._layers.eval_losses_on_samples(),
nest.flatten(tensors),
colocate_gradients_with_ops=self._colocate_gradients_with_ops)
grads_all = nest.pack_sequence_as(tensors, grads_flat)
return tuple((grad,) for grad in grads_all)
def _get_grads_lists_empirical(self, tensors):
# Passing in a list of loss values is better than passing in the sum as
# the latter creates unnecessary ops on the default device
grads_flat = gradients_impl.gradients(
self._layers.total_loss(),
self._layers.eval_losses(),
nest.flatten(tensors),
colocate_gradients_with_ops=self._colocate_gradients_with_ops)
grads_all = nest.pack_sequence_as(tensors, grads_flat)
@ -295,9 +546,10 @@ class FisherEstimator(object):
def _get_transformed_random_signs(self):
transformed_random_signs = []
for loss in self._layers.losses:
transformed_random_signs.append(
loss.multiply_fisher_factor(
utils.generate_random_signs(loss.fisher_factor_inner_shape)))
with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
transformed_random_signs.append(
loss.multiply_fisher_factor(
utils.generate_random_signs(loss.fisher_factor_inner_shape)))
return transformed_random_signs
def _get_grads_lists_curvature_prop(self, tensors):
@ -316,13 +568,14 @@ class FisherEstimator(object):
# Loop over all coordinates of all losses.
grads_all = []
for loss in self._layers.losses:
for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]):
transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
index)
grads_flat = gradients_impl.gradients(
loss.inputs,
nest.flatten(tensors),
grad_ys=transformed_one_hot,
colocate_gradients_with_ops=self._colocate_gradients_with_ops)
grads_all.append(nest.pack_sequence_as(tensors, grads_flat))
with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]):
transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
index)
grads_flat = gradients_impl.gradients(
loss.inputs,
nest.flatten(tensors),
grad_ys=transformed_one_hot,
colocate_gradients_with_ops=self._colocate_gradients_with_ops)
grads_all.append(nest.pack_sequence_as(tensors, grads_flat))
return zip(*grads_all)


@ -121,12 +121,44 @@ def compute_pi_adjusted_damping(left_cov, right_cov, damping):
return (damping, damping)
class PackagedFunc(object):
"""A Python thunk with a stable ID.
Enables stable names for lambdas.
"""
def __init__(self, func, func_id):
"""Initializes PackagedFunc.
Args:
func: a zero-arg Python function.
func_id: a hashable, a function that produces a hashable, or a list/tuple
thereof.
"""
self._func = func
func_id = func_id if isinstance(func_id, (tuple, list)) else (func_id,)
self._func_id = func_id
def __call__(self):
return self._func()
@property
def func_id(self):
"""A hashable identifier for this function."""
return tuple(elt() if callable(elt) else elt for elt in self._func_id)
def _package_func(func, func_id):
return PackagedFunc(func, func_id)
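# A minimal sketch (values are illustrative) of what PackagedFunc buys us:
# two lambdas wrapping the same damping value produce the same hashable id,
# so the computation they describe can be recognized as identical even though
# the lambda objects themselves never compare equal.
damping = 0.01
damping_func_a = _package_func(lambda: damping, (damping,))
damping_func_b = _package_func(lambda: damping, (damping,))
assert damping_func_a() == damping_func_b() == 0.01
assert damping_func_a.func_id == damping_func_b.func_id == (0.01,)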
@six.add_metaclass(abc.ABCMeta)
class FisherBlock(object):
"""Abstract base class for objects modeling approximate Fisher matrix blocks.
Subclasses must implement multiply_inverse(), instantiate_factors(), and
tensors_to_compute_grads() methods.
Subclasses must implement register_matpower, multiply_matpower,
instantiate_factors, tensors_to_compute_grads, and num_registered_minibatches
methods.
"""
def __init__(self, layer_collection):
@ -145,6 +177,32 @@ class FisherBlock(object):
pass
@abc.abstractmethod
def register_matpower(self, exp):
"""Registers a matrix power to be computed by the block.
Args:
exp: A float representing the power to raise the block by.
"""
pass
def register_inverse(self):
"""Registers a matrix inverse to be computed by the block."""
self.register_matpower(-1)
@abc.abstractmethod
def multiply_matpower(self, vector, exp):
"""Multiplies the vector by the (damped) matrix-power of the block.
Args:
vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
exp: A float representing the power to raise the block by before
multiplying it by the vector.
Returns:
The vector left-multiplied by the (damped) matrix-power of the block.
"""
pass
def multiply_inverse(self, vector):
"""Multiplies the vector by the (damped) inverse of the block.
@ -154,9 +212,8 @@ class FisherBlock(object):
Returns:
The vector left-multiplied by the (damped) inverse of the block.
"""
pass
return self.multiply_matpower(vector, -1)
@abc.abstractmethod
def multiply(self, vector):
"""Multiplies the vector by the (damped) block.
@ -166,7 +223,7 @@ class FisherBlock(object):
Returns:
The vector left-multiplied by the (damped) block.
"""
pass
return self.multiply_matpower(vector, 1)
@abc.abstractmethod
def tensors_to_compute_grads(self):
@ -207,21 +264,18 @@ class FullFB(FisherBlock):
super(FullFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
self._damping = damping
self._damping_func = _package_func(lambda: damping, (damping,))
self._factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullFactor, (grads_list, self._batch_size))
self._factor.register_damped_inverse(damping)
def multiply_inverse(self, vector):
vector_flat = utils.tensors_to_column(vector)
out_flat = self._factor.left_multiply_inverse(
vector_flat, self._damping)
return utils.column_to_tensors(vector, out_flat)
def register_matpower(self, exp):
self._factor.register_matpower(exp, self._damping_func)
def multiply(self, vector):
def multiply_matpower(self, vector, exp):
vector_flat = utils.tensors_to_column(vector)
out_flat = self._factor.left_multiply(
vector_flat, self._damping)
out_flat = self._factor.left_multiply_matpower(
vector_flat, exp, self._damping_func)
return utils.column_to_tensors(vector, out_flat)
def full_fisher_block(self):
@ -271,22 +325,20 @@ class NaiveDiagonalFB(FisherBlock):
super(NaiveDiagonalFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
self._damping = damping
self._damping_func = _package_func(lambda: damping, (damping,))
self._factor = self._layer_collection.make_or_get_factor(
fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size))
def multiply_inverse(self, vector):
vector_flat = utils.tensors_to_column(vector)
print("vector_flat: %s" % vector_flat)
out_flat = self._factor.left_multiply_inverse(
vector_flat, self._damping)
print("out_flat: %s" % out_flat)
return utils.column_to_tensors(vector, out_flat)
def register_matpower(self, exp):
# Not needed for this. Matrix powers are computed on demand in the
# diagonal case
pass
def multiply(self, vector):
def multiply_matpower(self, vector, exp):
vector_flat = utils.tensors_to_column(vector)
out_flat = self._factor.left_multiply(
vector_flat, self._damping)
out_flat = self._factor.left_multiply_matpower(
vector_flat, exp, self._damping_func)
return utils.column_to_tensors(vector, out_flat)
def full_fisher_block(self):
@ -312,7 +364,89 @@ class NaiveDiagonalFB(FisherBlock):
return math_ops.reduce_sum(self._batch_sizes)
class FullyConnectedDiagonalFB(FisherBlock):
class InputOutputMultiMinibatch(object):
"""Mix-in class for blocks with inputs & outputs and multiple mini-batches."""
def __init__(self, *args, **kwargs):
self.__inputs = []
self.__outputs = []
super(InputOutputMultiMinibatch, self).__init__(*args, **kwargs)
def tensors_to_compute_grads(self):
"""Tensors to compute derivative of loss with respect to."""
return self._outputs
def register_additional_minibatch(self, inputs, outputs):
self._inputs.append(inputs)
self._outputs.append(outputs)
@property
def num_registered_minibatches(self):
result = len(self._inputs)
assert result == len(self._outputs)
return result
@property
def _inputs(self):
return self.__inputs
@property
def _outputs(self):
return self.__outputs
def _package_minibatches(self, grads_list):
"""Constructs PartitionedTensor for inputs, grads_list.
The purpose of this method is to package up the towers/minibatch dimension
of these arrays into PartitionedTensor objects.
Args:
grads_list: 2-D list of Tensors. First index is for source, second
index for tower.
Returns:
inputs: PartitionedTensor.
grads_list: Tuple of PartitionedTensors, one per source.
"""
inputs = utils.PartitionedTensor(self._inputs)
grads_list = tuple(utils.PartitionedTensor(grads) for grads in grads_list)
return inputs, grads_list
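# Shape sketch (illustrative; sizes are made up, and `utils`/`array_ops` refer
# to the modules already imported by this file): with two towers and a single
# source, the tower dimension disappears into PartitionedTensor objects that
# the FisherFactors treat as single tensors.
inputs_tower0 = array_ops.zeros([32, 128])  # tower 0 activations
inputs_tower1 = array_ops.zeros([32, 128])  # tower 1 activations
grads_tower0 = array_ops.zeros([32, 64])    # tower 0 grads for the one source
grads_tower1 = array_ops.zeros([32, 64])    # tower 1 grads for the one source
packed_inputs = utils.PartitionedTensor([inputs_tower0, inputs_tower1])
packed_grads_list = (utils.PartitionedTensor([grads_tower0, grads_tower1]),)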
def _package_minibatches_multi(self, grads_list):
"""Constructs PartitionedTensors for inputs, grads_list.
The purpose of this method is to package up the towers/minibatch dimension
of these arrays into PartitionedTensor objects.
This version of this function is for use with FisherBlocks that deal with
multiple uses or time-steps. One PartitionedTensor is created for each
use/time-step.
Args:
grads_list: 3-D tuple of Tensors. First index is for source, second
index is for tower, third is for use/time-step.
Returns:
inputs: A tuple of PartitionedTensor's, one per use/time-step.
grads_list: 2-D tuple of PartitionedTensors. First index is for source,
second is for use/time-step.
"""
# self._inputs is a 2-D tuple. First index is tower/mini-batch, second is
# use/time-step.
inputs = self._inputs
num_uses = len(inputs[0])
assert all(len(input_) == num_uses for input_ in inputs)
assert all(len(grad) == num_uses for grads in grads_list for grad in grads)
inputs = tuple(utils.PartitionedTensor(input_) for input_ in zip(*inputs))
grads_list = tuple(tuple(utils.PartitionedTensor(grad)
for grad in zip(*grads)) for grads in grads_list)
return inputs, grads_list
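# Nesting sketch (illustrative) for the multi-use case: with 2 towers and 3
# uses/time-steps, self._inputs looks like
#   ((x_t0_u0, x_t0_u1, x_t0_u2),
#    (x_t1_u0, x_t1_u1, x_t1_u2))
# and the packaged `inputs` becomes one PartitionedTensor per use, each
# combining the towers:
#   (PartitionedTensor((x_t0_u0, x_t1_u0)),
#    PartitionedTensor((x_t0_u1, x_t1_u1)),
#    PartitionedTensor((x_t0_u2, x_t1_u2)))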
class FullyConnectedDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
"""FisherBlock for fully-connected (dense) layers using a diagonal approx.
Estimates the Fisher Information matrix's diagonal entries for a fully
@ -344,79 +478,45 @@ class FullyConnectedDiagonalFB(FisherBlock):
has_bias: Whether the component Kronecker factors have an additive bias.
(Default: False)
"""
self._inputs = []
self._outputs = []
self._has_bias = has_bias
super(FullyConnectedDiagonalFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
inputs = _concat_along_batch_dim(self._inputs)
grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
inputs, grads_list = self._package_minibatches(grads_list)
self._damping = damping
self._factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullyConnectedDiagonalFactor,
(inputs, grads_list, self._has_bias))
def multiply_inverse(self, vector):
"""Approximate damped inverse Fisher-vector product.
self._damping_func = _package_func(lambda: damping, (damping,))
def register_matpower(self, exp):
# Not needed for this. Matrix powers are computed on demand in the
# diagonal case
pass
def multiply_matpower(self, vector, exp):
"""Multiplies the vector by the (damped) matrix-power of the block.
Args:
vector: Tensor or 2-tuple of Tensors. If self._has_bias is False, a Tensor of
shape [input_size, output_size] corresponding to the layer's weights. If
self._has_bias is True, a 2-tuple of the former and a Tensor of shape
[output_size] corresponding to the layer's bias.
exp: A scalar representing the power to raise the block by before
multiplying it by the vector.
Returns:
Tensor of the same shape, corresponding to the inverse Fisher-vector
product.
The vector left-multiplied by the (damped) matrix-power of the block.
"""
reshaped_vec = utils.layer_params_to_mat2d(vector)
reshaped_out = self._factor.left_multiply_inverse(
reshaped_vec, self._damping)
reshaped_out = self._factor.left_multiply_matpower(
reshaped_vec, exp, self._damping_func)
return utils.mat2d_to_layer_params(vector, reshaped_out)
def multiply(self, vector):
"""Approximate damped Fisher-vector product.
Args:
vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape
[input_size, output_size] corresponding to layer's weights. If not, a
2-tuple of the former and a Tensor of shape [output_size] corresponding
to the layer's bias.
Returns:
Tensor of the same shape, corresponding to the Fisher-vector product.
"""
reshaped_vec = utils.layer_params_to_mat2d(vector)
reshaped_out = self._factor.left_multiply(
reshaped_vec, self._damping)
return utils.mat2d_to_layer_params(vector, reshaped_out)
def tensors_to_compute_grads(self):
"""Tensors to compute derivative of loss with respect to."""
return self._outputs
def register_additional_minibatch(self, inputs, outputs):
"""Registers an additional minibatch to the FisherBlock.
Args:
inputs: Tensor of shape [batch_size, input_size]. Inputs to the
matrix-multiply.
outputs: Tensor of shape [batch_size, output_size]. Layer preactivations.
"""
self._inputs.append(inputs)
self._outputs.append(outputs)
@property
def num_registered_minibatches(self):
result = len(self._inputs)
assert result == len(self._outputs)
return result
class ConvDiagonalFB(FisherBlock):
class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
"""FisherBlock for convolutional layers using a diagonal approx.
Estimates the Fisher Information matrix's diagonal entries for a convolutional
@ -454,8 +554,6 @@ class ConvDiagonalFB(FisherBlock):
strides: The stride size in this layer (1-D Tensor of length 4).
padding: The padding in this layer (e.g. "SAME").
"""
self._inputs = []
self._outputs = []
self._strides = tuple(strides) if isinstance(strides, list) else strides
self._padding = padding
self._has_bias = isinstance(params, (tuple, list))
@ -466,55 +564,38 @@ class ConvDiagonalFB(FisherBlock):
super(ConvDiagonalFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
# Concatenate inputs, grads_list into single Tensors.
inputs = _concat_along_batch_dim(self._inputs)
grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
# Infer number of locations upon which convolution is applied.
inputs_shape = tuple(inputs.shape.as_list())
inputs_shape = tuple(self._inputs[0].shape.as_list())
self._num_locations = (
inputs_shape[1] * inputs_shape[2] //
(self._strides[1] * self._strides[2]))
self._damping = (self._num_locations
* normalize_damping(damping, self._num_locations))
inputs, grads_list = self._package_minibatches(grads_list)
self._factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvDiagonalFactor,
(inputs, grads_list, self._filter_shape, self._strides, self._padding,
self._has_bias))
(inputs, grads_list, self._filter_shape, self._strides,
self._padding, self._has_bias))
def multiply_inverse(self, vector):
def damping_func():
return self._num_locations * normalize_damping(damping,
self._num_locations)
damping_id = (self._num_locations, "mult", "normalize_damping", damping,
self._num_locations)
self._damping_func = _package_func(damping_func, damping_id)
def register_matpower(self, exp):
# Not needed for this. Matrix powers are computed on demand in the
# diagonal case
pass
def multiply_matpower(self, vector, exp):
reshaped_vect = utils.layer_params_to_mat2d(vector)
reshaped_out = self._factor.left_multiply_inverse(
reshaped_vect, self._damping)
reshaped_out = self._factor.left_multiply_matpower(
reshaped_vect, exp, self._damping_func)
return utils.mat2d_to_layer_params(vector, reshaped_out)
def multiply(self, vector):
reshaped_vect = utils.layer_params_to_mat2d(vector)
reshaped_out = self._factor.left_multiply(
reshaped_vect, self._damping)
return utils.mat2d_to_layer_params(vector, reshaped_out)
def tensors_to_compute_grads(self):
return self._outputs
def register_additional_minibatch(self, inputs, outputs):
"""Registers an additional minibatch to the FisherBlock.
Args:
inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to
the convolution.
outputs: Tensor of shape [batch_size, height, width, output_size]. Layer
preactivations.
"""
self._inputs.append(inputs)
self._outputs.append(outputs)
@property
def num_registered_minibatches(self):
return len(self._inputs)
class KroneckerProductFB(FisherBlock):
"""A base class for FisherBlocks with separate input and output factors.
@ -523,22 +604,40 @@ class KroneckerProductFB(FisherBlock):
output factors.
"""
def _register_damped_input_and_output_inverses(self, damping):
"""Registers damped inverses for both the input and output factors.
def __init__(self, layer_collection):
super(KroneckerProductFB, self).__init__(layer_collection)
Sets the instance members _input_damping and _output_damping. Requires the
instance members _input_factor and _output_factor.
def _setup_damping(self, damping, normalization=None):
"""Makes functions that compute the damping values for both factors."""
def compute_damping():
if normalization is not None:
maybe_normalized_damping = normalize_damping(damping, normalization)
else:
maybe_normalized_damping = damping
Args:
damping: The base damping factor (float or Tensor) for the damped inverse.
"""
self._input_damping, self._output_damping = compute_pi_adjusted_damping(
self._input_factor.get_cov(),
self._output_factor.get_cov(),
damping**0.5)
return compute_pi_adjusted_damping(self._input_factor.get_cov(),
self._output_factor.get_cov(),
maybe_normalized_damping**0.5)
self._input_factor.register_damped_inverse(self._input_damping)
self._output_factor.register_damped_inverse(self._output_damping)
if normalization is not None:
damping_id = ("compute_pi_adjusted_damping",
"cov", self._input_factor.name,
"cov", self._output_factor.name,
"normalize_damping", damping, normalization, "power", 0.5)
else:
damping_id = ("compute_pi_adjusted_damping",
"cov", self._input_factor.name,
"cov", self._output_factor.name,
damping, "power", 0.5)
self._input_damping_func = _package_func(lambda: compute_damping()[0],
damping_id + ("ref", 0))
self._output_damping_func = _package_func(lambda: compute_damping()[1],
damping_id + ("ref", 1))
def register_matpower(self, exp):
self._input_factor.register_matpower(exp, self._input_damping_func)
self._output_factor.register_matpower(exp, self._output_damping_func)
@property
def _renorm_coeff(self):
@ -552,28 +651,15 @@ class KroneckerProductFB(FisherBlock):
"""
return 1.0
def multiply_inverse(self, vector):
def multiply_matpower(self, vector, exp):
reshaped_vector = utils.layer_params_to_mat2d(vector)
reshaped_out = self._output_factor.right_multiply_inverse(
reshaped_vector,
self._output_damping)
reshaped_out = self._input_factor.left_multiply_inverse(
reshaped_out, self._input_damping)
if self._renorm_coeff != 1.0:
reshaped_out /= math_ops.cast(
self._renorm_coeff, dtype=reshaped_out.dtype)
return utils.mat2d_to_layer_params(vector, reshaped_out)
def multiply(self, vector):
reshaped_vector = utils.layer_params_to_mat2d(vector)
reshaped_out = self._output_factor.right_multiply(
reshaped_vector,
self._output_damping)
reshaped_out = self._input_factor.left_multiply(
reshaped_out, self._input_damping)
reshaped_out = self._output_factor.right_multiply_matpower(
reshaped_vector, exp, self._output_damping_func)
reshaped_out = self._input_factor.left_multiply_matpower(
reshaped_out, exp, self._input_damping_func)
if self._renorm_coeff != 1.0:
reshaped_out *= math_ops.cast(
self._renorm_coeff, dtype=reshaped_out.dtype)
self._renorm_coeff**exp, dtype=reshaped_out.dtype)
return utils.mat2d_to_layer_params(vector, reshaped_out)
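# Math sketch for multiply_matpower above: the block approximates the Fisher
# as (renorm coeff c) times a Kronecker product of the input factor A and the
# output factor G. With the parameters reshaped into a matrix V of shape
# [input_size, output_size], applying the exp-th matrix power of that
# approximation (damping aside) reduces to
#   V  ->  c^exp * A^exp * V * G^exp
# which is exactly the right-multiply by G^exp, left-multiply by A^exp, and
# rescale by c^exp performed here.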
def full_fisher_block(self):
@ -590,7 +676,7 @@ class KroneckerProductFB(FisherBlock):
right_factor)
class EmbeddingKFACFB(KroneckerProductFB):
class EmbeddingKFACFB(InputOutputMultiMinibatch, KroneckerProductFB):
"""K-FAC FisherBlock for embedding layers.
This FisherBlock is similar to FullyConnectedKFACBasicFB, except that its
@ -608,8 +694,6 @@ class EmbeddingKFACFB(KroneckerProductFB):
Fisher information matrix to which this FisherBlock belongs.
vocab_size: int. Size of vocabulary for this embedding layer.
"""
self._inputs = []
self._outputs = []
self._vocab_size = vocab_size
super(EmbeddingKFACFB, self).__init__(layer_collection)
@ -624,41 +708,18 @@ class EmbeddingKFACFB(KroneckerProductFB):
damping: 0-D Tensor or float. 'damping' * identity is approximately added
to this FisherBlock's Fisher approximation.
"""
# TODO(b/68033310): Validate which of,
# (1) summing on a single device (as below), or
# (2) on each device in isolation and aggregating
# is faster.
inputs = _concat_along_batch_dim(self._inputs)
grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
inputs, grads_list = self._package_minibatches(grads_list)
self._input_factor = self._layer_collection.make_or_get_factor( #
fisher_factors.EmbeddingInputKroneckerFactor, #
((inputs,), self._vocab_size))
(inputs, self._vocab_size))
self._output_factor = self._layer_collection.make_or_get_factor( #
fisher_factors.FullyConnectedKroneckerFactor, #
(grads_list,))
self._register_damped_input_and_output_inverses(damping)
def tensors_to_compute_grads(self):
return self._outputs
def register_additional_minibatch(self, inputs, outputs):
"""Registers an additional minibatch to the FisherBlock.
Args:
inputs: Tensor of shape [batch_size, input_size]. Inputs to the
matrix-multiply.
outputs: Tensor of shape [batch_size, output_size]. Layer preactivations.
"""
self._inputs.append(inputs)
self._outputs.append(outputs)
@property
def num_registered_minibatches(self):
return len(self._inputs)
self._setup_damping(damping)
class FullyConnectedKFACBasicFB(KroneckerProductFB):
class FullyConnectedKFACBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
"""K-FAC FisherBlock for fully-connected (dense) layers.
This uses the Kronecker-factorized approximation from the original
@ -674,8 +735,6 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
has_bias: Whether the component Kronecker factors have an additive bias.
(Default: False)
"""
self._inputs = []
self._outputs = []
self._has_bias = has_bias
super(FullyConnectedKFACBasicFB, self).__init__(layer_collection)
@ -690,12 +749,7 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
damping: 0-D Tensor or float. 'damping' * identity is approximately added
to this FisherBlock's Fisher approximation.
"""
# TODO(b/68033310): Validate which of,
# (1) summing on a single device (as below), or
# (2) on each device in isolation and aggregating
# is faster.
inputs = _concat_along_batch_dim(self._inputs)
grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
inputs, grads_list = self._package_minibatches(grads_list)
self._input_factor = self._layer_collection.make_or_get_factor( #
fisher_factors.FullyConnectedKroneckerFactor, #
@ -703,28 +757,10 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
self._output_factor = self._layer_collection.make_or_get_factor( #
fisher_factors.FullyConnectedKroneckerFactor, #
(grads_list,))
self._register_damped_input_and_output_inverses(damping)
def tensors_to_compute_grads(self):
return self._outputs
def register_additional_minibatch(self, inputs, outputs):
"""Registers an additional minibatch to the FisherBlock.
Args:
inputs: Tensor of shape [batch_size, input_size]. Inputs to the
matrix-multiply.
outputs: Tensor of shape [batch_size, output_size]. Layer preactivations.
"""
self._inputs.append(inputs)
self._outputs.append(outputs)
@property
def num_registered_minibatches(self):
return len(self._inputs)
self._setup_damping(damping)
class ConvKFCBasicFB(KroneckerProductFB):
class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
"""FisherBlock for 2D convolutional layers using the basic KFC approx.
Estimates the Fisher Information matrix's block for a convolutional
@ -761,8 +797,6 @@ class ConvKFCBasicFB(KroneckerProductFB):
strides: The stride size in this layer (1-D Tensor of length 4).
padding: The padding in this layer (e.g. "SAME").
"""
self._inputs = []
self._outputs = []
self._strides = tuple(strides) if isinstance(strides, list) else strides
self._padding = padding
self._has_bias = isinstance(params, (tuple, list))
@ -773,17 +807,12 @@ class ConvKFCBasicFB(KroneckerProductFB):
super(ConvKFCBasicFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
# TODO(b/68033310): Validate which of,
# (1) summing on a single device (as below), or
# (2) on each device in isolation and aggregating
# is faster.
inputs = _concat_along_batch_dim(self._inputs)
grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
# Infer number of locations upon which convolution is applied.
self._num_locations = num_conv_locations(inputs.shape.as_list(),
self._num_locations = num_conv_locations(self._inputs[0].shape.as_list(),
self._strides)
inputs, grads_list = self._package_minibatches(grads_list)
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvInputKroneckerFactor,
(inputs, self._filter_shape, self._strides, self._padding,
@ -791,60 +820,12 @@ class ConvKFCBasicFB(KroneckerProductFB):
self._output_factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
damping = normalize_damping(damping, self._num_locations)
self._register_damped_input_and_output_inverses(damping)
self._damping = damping
self._setup_damping(damping, normalization=self._num_locations)
@property
def _renorm_coeff(self):
return self._num_locations
def tensors_to_compute_grads(self):
return self._outputs
def register_additional_minibatch(self, inputs, outputs):
"""Registers an additional minibatch to the FisherBlock.
Args:
inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to
the convolution.
outputs: Tensor of shape [batch_size, height, width, output_size]. Layer
preactivations.
"""
self._inputs.append(inputs)
self._outputs.append(outputs)
@property
def num_registered_minibatches(self):
return len(self._inputs)
def _concat_along_batch_dim(tensor_list):
"""Concatenate tensors along batch (first) dimension.
Args:
tensor_list: list of Tensors or list of tuples of Tensors.
Returns:
Tensor or tuple of Tensors.
Raises:
ValueError: If 'tensor_list' is empty.
"""
if not tensor_list:
raise ValueError(
"Cannot concatenate Tensors if there are no Tensors to concatenate.")
if isinstance(tensor_list[0], (tuple, list)):
# [(tensor1a, tensor1b),
# (tensor2a, tensor2b), ...] --> (tensor_a, tensor_b)
return tuple(
array_ops.concat(tensors, axis=0) for tensors in zip(*tensor_list))
else:
# [tensor1, tensor2] --> tensor
return array_ops.concat(tensor_list, axis=0)
def num_conv_locations(input_shape, strides):
"""Returns the number of spatial locations a 2D Conv kernel is applied to.
@ -859,49 +840,35 @@ def num_conv_locations(input_shape, strides):
return input_shape[1] * input_shape[2] // (strides[1] * strides[2])
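# Quick arithmetic check of the formula above (shapes are illustrative):
# a [batch, 32, 32, channels] input with strides [1, 2, 2, 1] is covered by
# 32 * 32 // (2 * 2) = 256 spatial locations.
assert num_conv_locations([128, 32, 32, 3], [1, 2, 2, 1]) == 256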
class FullyConnectedMultiIndepFB(KroneckerProductFB):
class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
"""FisherBlock for fully-connected layers that share parameters.
"""
def __init__(self, layer_collection, inputs, outputs, has_bias=False):
def __init__(self, layer_collection, has_bias=False):
"""Creates a FullyConnectedMultiIndepFB block.
Args:
layer_collection: LayerCollection instance.
inputs: list or tuple of Tensors. Each Tensor has shape [batch_size,
inputs_size].
outputs: list or tuple of Tensors. Each Tensor has shape [batch_size,
outputs_size].
has_bias: bool. If True, estimates Fisher with respect to a bias
parameter as well as the layer's parameters.
"""
assert len(inputs) == len(outputs)
# We need to make sure inputs and outputs are tuples and not lists so that
# they get hashed by layer_collection.make_or_get_factor properly.
self._inputs = tuple(inputs)
self._outputs = tuple(outputs)
self._has_bias = has_bias
self._num_uses = len(inputs)
super(FullyConnectedMultiIndepFB, self).__init__(layer_collection)
@property
def num_registered_minibatches(self):
# TODO(b/69411207): Add support for registering additional minibatches.
return 1
def instantiate_factors(self, grads_list, damping):
self._num_uses = len(self._inputs[0])
inputs, grads_list = self._package_minibatches_multi(grads_list)
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullyConnectedMultiKF,
((self._inputs,), self._has_bias))
((inputs,), self._has_bias))
self._output_factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullyConnectedMultiKF, (grads_list,))
damping = normalize_damping(damping, self._num_uses)
self._register_damped_input_and_output_inverses(damping)
self._setup_damping(damping, normalization=self._num_uses)
@property
def _renorm_coeff(self):
@ -910,9 +877,6 @@ class FullyConnectedMultiIndepFB(KroneckerProductFB):
def tensors_to_compute_grads(self):
return self._outputs
def num_inputs(self):
return len(self._inputs)
class SeriesFBApproximation(enum.IntEnum):
"""See FullyConnectedSeriesFB.__init__ for description and usage."""
@ -920,22 +884,20 @@ class SeriesFBApproximation(enum.IntEnum):
option2 = 2
class FullyConnectedSeriesFB(FisherBlock):
class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock):
"""FisherBlock for fully-connected layers that share parameters across time.
See the following preprint for details:
https://openreview.net/pdf?id=HyMTkQZAb
See the end of the appendix of the paper for a pseudo-code of the
algorithm being implemented by multiply_inverse here. Note that we are
algorithm being implemented by multiply_matpower here. Note that we are
using pre-computed versions of certain matrix-matrix products to speed
things up. This is explicitly explained wherever it is done.
"""
def __init__(self,
layer_collection,
inputs,
outputs,
has_bias=False,
option=SeriesFBApproximation.option2):
"""Constructs a new `FullyConnectedSeriesFB`.
@ -943,10 +905,6 @@ class FullyConnectedSeriesFB(FisherBlock):
Args:
layer_collection: The collection of all layers in the K-FAC approximate
Fisher information matrix to which this FisherBlock belongs.
inputs: List of tensors of shape [batch_size, input_size].
Inputs to the layer.
outputs: List of tensors of shape [batch_size, input_size].
Outputs of the layer (before activations).
has_bias: Whether the layer includes a bias parameter.
option: A `SeriesFBApproximation` specifying the simplifying assumption
to be used in this block. `option1` approximates the cross-covariance
@ -955,48 +913,61 @@ class FullyConnectedSeriesFB(FisherBlock):
3.5 of the paper for more details.
"""
assert len(inputs) == len(outputs)
# We need to make sure inputs and outputs are tuples and not lists so that
# they get hashed by layer_collection.make_or_get_factor properly.
self._inputs = tuple(inputs)
self._outputs = tuple(outputs)
self._has_bias = has_bias
self._num_timesteps = len(inputs)
self._option = option
super(FullyConnectedSeriesFB, self).__init__(layer_collection)
@property
def num_registered_minibatches(self):
# TODO(b/69411207): Add support for registering additional minibatches.
return 1
def instantiate_factors(self, grads_list, damping):
self._num_timesteps = len(self._inputs[0])
inputs, grads_list = self._package_minibatches_multi(grads_list)
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullyConnectedMultiKF, ((self._inputs,), self._has_bias))
fisher_factors.FullyConnectedMultiKF, ((inputs,), self._has_bias))
self._input_factor.register_cov_dt1()
self._output_factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullyConnectedMultiKF, (grads_list,))
self._output_factor.register_cov_dt1()
damping = normalize_damping(damping, self._num_timesteps)
self._damping_input, self._damping_output = compute_pi_adjusted_damping(
self._input_factor.get_cov(),
self._output_factor.get_cov(),
damping**0.5)
def compute_damping():
normalized_damping = normalize_damping(damping, self._num_timesteps)
return compute_pi_adjusted_damping(self._input_factor.get_cov(),
self._output_factor.get_cov(),
normalized_damping**0.5)
damping_id = ("compute_pi_adjusted_damping",
"cov", self._input_factor.name,
"cov", self._output_factor.name,
"normalize_damping",
damping, self._num_timesteps, "power", 0.5)
self._input_damping_func = _package_func(lambda: compute_damping()[0],
damping_id + ("ref", 0))
self._output_damping_func = _package_func(lambda: compute_damping()[1],
damping_id + ("ref", 1))
def register_matpower(self, exp):
if exp != -1:
raise NotImplementedError("FullyConnectedSeriesFB only supports inverse"
"multiplications.")
if self._option == SeriesFBApproximation.option1:
self._input_factor.register_option1quants(self._damping_input)
self._output_factor.register_option1quants(self._damping_output)
self._input_factor.register_option1quants(self._input_damping_func)
self._output_factor.register_option1quants(self._output_damping_func)
elif self._option == SeriesFBApproximation.option2:
self._input_factor.register_option2quants(self._damping_input)
self._output_factor.register_option2quants(self._damping_output)
self._input_factor.register_option2quants(self._input_damping_func)
self._output_factor.register_option2quants(self._output_damping_func)
else:
raise ValueError(
"Unrecognized FullyConnectedSeriesFB approximation: {}".format(
self._option))
def multiply_inverse(self, vector):
def multiply_matpower(self, vector, exp):
if exp != -1:
raise NotImplementedError("FullyConnectedSeriesFB only supports inverse"
"multiplications.")
# pylint: disable=invalid-name
Z = utils.layer_params_to_mat2d(vector)
@ -1008,8 +979,10 @@ class FullyConnectedSeriesFB(FisherBlock):
if self._option == SeriesFBApproximation.option1:
# Note that L_A = A0^(-1/2) * U_A and L_G = G0^(-1/2) * U_G.
L_A, psi_A = self._input_factor.get_option1quants(self._damping_input)
L_G, psi_G = self._output_factor.get_option1quants(self._damping_output)
L_A, psi_A = self._input_factor.get_option1quants(
self._input_damping_func)
L_G, psi_G = self._output_factor.get_option1quants(
self._output_damping_func)
def gamma(x):
# We are assuming that each case has the same number of time-steps.
@ -1046,9 +1019,10 @@ class FullyConnectedSeriesFB(FisherBlock):
# Note that P_A = A_1^T * A_0^(-1) and P_G = G_1^T * G_0^(-1),
# and K_A = A_0^(-1/2) * E_A and K_G = G_0^(-1/2) * E_G.
P_A, K_A, mu_A = self._input_factor.get_option2quants(self._damping_input)
P_A, K_A, mu_A = self._input_factor.get_option2quants(
self._input_damping_func)
P_G, K_G, mu_G = self._output_factor.get_option2quants(
self._damping_output)
self._output_damping_func)
# Our approach differs superficially from the pseudo-code in the paper
# in order to reduce the total number of matrix-matrix multiplies.
@ -1102,11 +1076,5 @@ class FullyConnectedSeriesFB(FisherBlock):
# pylint: enable=invalid-name
def multiply(self, vector):
raise NotImplementedError
def tensors_to_compute_grads(self):
return self._outputs
def num_inputs(self):
return len(self._inputs)

File diff suppressed because it is too large.


@ -130,6 +130,8 @@ class LayerCollection(object):
fisher_factors: an OrderedDict mapping tuples to FisherFactor instances.
losses: a list of LossFunction objects. The loss to be optimized is their
sum.
loss_colocation_ops: ops to colocate loss function evaluations with. These
will typically be the inputs to the losses.
"""
def __init__(self,
@ -148,14 +150,21 @@ class LayerCollection(object):
self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME
self._default_fully_connected_multi_approximation = (
APPROX_KRONECKER_SERIES_2_NAME)
self.loss_colocation_ops = {}
self._vars_to_uses = defaultdict(lambda: 0)
with variable_scope.variable_scope(None, default_name=name) as scope:
self._var_scope = scope.name
@property
def losses(self):
"""LossFunctions registered with this LayerCollection."""
return list(self._loss_dict.values())
"""Tuple of LossFunction objects registered with this LayerCollection."""
return nest.flatten(self.towers_by_loss)
@property
def towers_by_loss(self):
"""Tuple across losses of LossFunction objects registered to each tower."""
return tuple(tuple(lst) for lst in self._loss_dict.values())
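# Structure sketch (illustrative): if a loss named "loss0" was registered on
# two towers and a loss named "loss1" on one tower, then
#   towers_by_loss == ((loss0_tower0, loss0_tower1), (loss1_tower0,))
#   losses         == [loss0_tower0, loss0_tower1, loss1_tower0]
# i.e. `losses` is just the flattened, per-tower view of `towers_by_loss`.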
@property
def registered_variables(self):
@ -290,23 +299,74 @@ class LayerCollection(object):
self.fisher_blocks[layer_key] = fisher_block
return fisher_block
def get_use_count_map(self):
"""Returns a dict of variables to their number of registrations."""
# TODO(b/70283403): Reimplement this in the old way, where each
# registration function would be responsible for incrementing the count.
# Also, this version has a bug: it won't do the right thing for generic
# registration for parameters that are shared. i.e. it won't set the use
# count to infinity.
vars_to_uses = defaultdict(int)
for key, block in six.iteritems(self.fisher_blocks):
n = (
block.num_inputs()*block.num_registered_minibatches if isinstance(
block, (fb.FullyConnectedSeriesFB, fb.FullyConnectedMultiIndepFB))
else block.num_registered_minibatches)
key = utils.ensure_sequence(key)
for k in key:
vars_to_uses[k] += n
return vars_to_uses
def register_loss_function(self,
loss,
colocation_op,
base_name,
name=None,
reuse=VARIABLE_SCOPE):
"""Registers a LossFunction object.
Args:
loss: The LossFunction object.
colocation_op: The op to colocate the loss function's computations with.
base_name: The name to derive a new unique name from if the name argument
is None.
name: (OPTIONAL) str or None. Unique name for this loss function. If None,
a new name is generated. (Default: None)
reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock.
If False, create a new FisherBlock. If VARIABLE_SCOPE, use
tf.get_variable_scope().reuse.
Raises:
ValueError: If reuse == True and name == None.
KeyError: If reuse == True and no existing LossFunction with 'name' found.
KeyError: If reuse == False and existing LossFunction with 'name' found.
"""
name = name or self._graph.unique_name(base_name)
if reuse == VARIABLE_SCOPE:
reuse = variable_scope.get_variable_scope().reuse
if reuse:
if name is None:
raise ValueError(
"If reuse is enabled, loss function's name must be set.")
loss_list = self._loss_dict.get(name, None)
if loss_list is None:
raise KeyError(
"Unable to find loss function named {}. Register a new loss "
"function with reuse=False.".format(name))
else:
if name in self._loss_dict:
raise KeyError(
"Loss function named {} already exists. Set reuse=True to append "
"another minibatch/tower.".format(name))
loss_list = []
self._loss_dict[name] = loss_list
loss_list.append(loss)
self.loss_colocation_ops[loss] = colocation_op
def _get_use_count_map(self):
"""Returns a dict mapping variables to their number of registrations."""
return self._vars_to_uses
def _add_uses(self, params, uses):
"""Register additional uses by params in the graph.
Args:
params: Variable or tuple of Variables. Parameters for a layer.
uses: int or float. Number of additional uses for these parameters.
"""
params = params if isinstance(params, (tuple, list)) else (params,)
for var in params:
self._vars_to_uses[var] += uses
def check_registration(self, variables):
"""Checks that all variable uses have been registered properly.
@ -324,7 +384,7 @@ class LayerCollection(object):
# Note that overlapping parameters (i.e. those that share variables) will
# be caught by layer_collection.LayerParametersDict during registration.
reg_use_map = self.get_use_count_map()
reg_use_map = self._get_use_count_map()
error_messages = []
@ -414,12 +474,27 @@ class LayerCollection(object):
inputs_to_losses = nest.flatten(tuple(loss.inputs for loss in self.losses))
self._subgraph = utils.SubGraph(inputs_to_losses)
def eval_losses(self):
"""Return evaluated losses (colocated with inputs to losses)."""
evals = []
for loss in self.losses:
with ops.colocate_with(self.loss_colocation_ops[loss]):
evals.append(loss.evaluate())
return evals
def eval_losses_on_samples(self):
"""Return losses evaluated on samples (colocated with inputs to losses)."""
evals = []
for loss in self.losses:
with ops.colocate_with(self.loss_colocation_ops[loss]):
evals.append(loss.evaluate_on_sample())
return evals
def total_loss(self):
return math_ops.add_n(tuple(loss.evaluate() for loss in self.losses))
return math_ops.add_n(self.eval_losses())
def total_sampled_loss(self):
return math_ops.add_n(
tuple(loss.evaluate_on_sample() for loss in self.losses))
return math_ops.add_n(self.eval_losses_on_samples())
def _get_linked_approx(self, params):
"""If params were linked, return their specified approximation."""
@ -469,6 +544,8 @@ class LayerCollection(object):
params, fb.EmbeddingKFACFB(self, vocab_size), reuse=reuse)
block.register_additional_minibatch(inputs, outputs)
self._add_uses(params, 1)
def register_fully_connected(self,
params,
inputs,
@ -505,9 +582,12 @@ class LayerCollection(object):
block_type = _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES[approx]
has_bias = isinstance(params, (tuple, list))
block = self.register_block(params, block_type(self, has_bias), reuse=reuse)
block = self.register_block(params, block_type(self, has_bias=has_bias),
reuse=reuse)
block.register_additional_minibatch(inputs, outputs)
self._add_uses(params, 1)
def register_conv2d(self,
params,
strides,
@ -553,6 +633,8 @@ class LayerCollection(object):
params, block_type(self, params, strides, padding), reuse=reuse)
block.register_additional_minibatch(inputs, outputs)
self._add_uses(params, 1)
def register_generic(self,
params,
batch_size,
@ -586,8 +668,10 @@ class LayerCollection(object):
block = self.register_block(params, block_type(self, params), reuse=reuse)
block.register_additional_minibatch(batch_size)
self._add_uses(params, float("inf"))
def register_fully_connected_multi(self, params, inputs, outputs,
approx=None):
approx=None, reuse=VARIABLE_SCOPE):
"""Register fully connected layers with shared parameters.
This can handle general fully-connected layers with shared parameters, but
@ -604,6 +688,9 @@ class LayerCollection(object):
[batch_size, output_size]. Outputs produced by layer. In the case of
RNNs, one Tensor per time step.
approx: str. One of "kron_indep", "kron_series_1", or "kron_series_2".
reuse: bool or str. If True, reuse an existing FisherBlock. If False,
create a new FisherBlock. If "VARIABLE_SCOPE", use
tf.get_variable_scope().reuse.
Raises:
ValueError: For improper value to 'approx'.
@ -621,11 +708,14 @@ class LayerCollection(object):
raise ValueError("Bad value {} for approx.".format(approx))
block_type = _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES[approx]
# For now we don't support multiple minibatches for this type of layer, so
# we set reuse=False
self.register_block(params,
block_type(self, inputs, outputs, has_bias=has_bias),
reuse=False)
block = self.register_block(params, block_type(self, has_bias=has_bias),
reuse=reuse)
block.register_additional_minibatch(inputs, outputs)
self._add_uses(params, len(inputs))
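# Usage sketch (illustrative; `tf`, the tensors and the LayerCollection
# instance are assumed to exist): registering an RNN's shared fully-connected
# weights with one (input, output) pair per time step.
shared_w = tf.get_variable("shared_w", shape=[128, 128])
inputs = [x_t0, x_t1, x_t2]          # each of shape [batch_size, 128]
outputs = [pre_t0, pre_t1, pre_t2]   # preactivations, each [batch_size, 128]
layer_collection.register_fully_connected_multi(shared_w, inputs, outputs)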
# TODO(b/74108452): change the loss registration functions names to refer
# to "loss functions" instead of distributions. Following naming convention
# of the loss function classes themselves.
def register_categorical_predictive_distribution(self,
logits,
@ -648,50 +738,20 @@ class LayerCollection(object):
reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock.
If False, create a new FisherBlock. If VARIABLE_SCOPE, use
tf.get_variable_scope().reuse.
Raises:
ValueError: If reuse == True and name == None.
ValueError: If reuse == True and seed != None.
KeyError: If reuse == True and no existing LossFunction with 'name' found.
KeyError: If reuse == False and existing LossFunction with 'name' found.
"""
name = name or self._graph.unique_name(
"register_categorical_predictive_distribution")
if reuse == VARIABLE_SCOPE:
reuse = variable_scope.get_variable_scope().reuse
if reuse:
if name is None:
raise ValueError(
"If reuse is enabled, loss function's name must be set.")
if seed is not None:
raise ValueError(
"Seed can only be specified at LossFunction instantiation.")
loss = self._loss_dict.get(name, None)
if loss is None:
raise KeyError(
"Unable to find loss function named {}. Create a new LossFunction "
"with reuse=False.".format(name))
loss.register_additional_minibatch(logits, targets=targets)
else:
if name in self._loss_dict:
raise KeyError(
"Loss function named {} already exists. Set reuse=True to append "
"another minibatch.".format(name))
loss = lf.CategoricalLogitsNegativeLogProbLoss(
logits, targets=targets, seed=seed)
self._loss_dict[name] = loss
loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets,
seed=seed)
self.register_loss_function(loss, logits,
"categorical_predictive_distribution",
name=name, reuse=reuse)
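# Usage sketch (illustrative; the tensors and names are hypothetical):
# registering the same named loss once per tower, so that each tower's loss
# computations stay colocated with that tower's logits.
layer_collection.register_categorical_predictive_distribution(
    logits_tower0, name="softmax_loss", reuse=False)
layer_collection.register_categorical_predictive_distribution(
    logits_tower1, name="softmax_loss", reuse=True)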
def register_normal_predictive_distribution(self,
mean,
var=0.5,
seed=None,
targets=None,
name=None):
name=None,
reuse=VARIABLE_SCOPE):
"""Registers a normal predictive distribution.
Args:
@ -708,21 +768,22 @@ class LayerCollection(object):
(Default: None)
name: (OPTIONAL) str or None. Unique name for this loss function. If None,
a new name is generated. (Default: None)
reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock.
If False, create a new FisherBlock. If VARIABLE_SCOPE, use
tf.get_variable_scope().reuse.
"""
name = name or self._graph.unique_name(
"register_normal_predictive_distribution")
if name in self._loss_dict:
raise NotImplementedError(
"Adding logits to an existing LossFunction not yet supported.")
loss = lf.NormalMeanNegativeLogProbLoss(
mean, var, targets=targets, seed=seed)
self._loss_dict[name] = loss
loss = lf.NormalMeanNegativeLogProbLoss(mean, var, targets=targets,
seed=seed)
self.register_loss_function(loss, mean,
"normal_predictive_distribution",
name=name, reuse=reuse)
def register_multi_bernoulli_predictive_distribution(self,
logits,
seed=None,
targets=None,
name=None):
name=None,
reuse=VARIABLE_SCOPE):
"""Registers a multi-Bernoulli predictive distribution.
Args:
@ -735,15 +796,15 @@ class LayerCollection(object):
(Default: None)
name: (OPTIONAL) str or None. Unique name for this loss function. If None,
a new name is generated. (Default: None)
reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock.
If False, create a new FisherBlock. If VARIABLE_SCOPE, use
tf.get_variable_scope().reuse.
"""
name = name or self._graph.unique_name(
"register_multi_bernoulli_predictive_distribution")
if name in self._loss_dict:
raise NotImplementedError(
"Adding logits to an existing LossFunction not yet supported.")
loss = lf.MultiBernoulliNegativeLogProbLoss(
logits, targets=targets, seed=seed)
self._loss_dict[name] = loss
loss = lf.MultiBernoulliNegativeLogProbLoss(logits, targets=targets,
seed=seed)
self.register_loss_function(loss, logits,
"multi_bernoulli_predictive_distribution",
name=name, reuse=reuse)
def make_or_get_factor(self, cls, args):
"""Insert 'cls(args)' into 'self.fisher_factors' if not already present.


@ -57,30 +57,6 @@ class LossFunction(object):
"""The inputs to the loss function (excluding the targets)."""
pass
@property
def input_minibatches(self):
"""A `list` of inputs to the loss function, separated by minibatch.
Typically there will be one minibatch per tower in a multi-tower setup.
Returns a list consisting of `self.inputs` by default; `LossFunction`s
supporting registering multiple minibatches should override this method.
Returns:
A `list` of `Tensor`s representing
"""
return [self.inputs]
@property
def num_registered_minibatches(self):
"""Number of minibatches registered for this LossFunction.
Typically equal to the number of towers in a multi-tower setup.
Returns:
An `int` representing the number of registered minibatches.
"""
return len(self.input_minibatches)
def evaluate(self):
"""Evaluate the loss function on the targets."""
if self.targets is not None:
@ -474,7 +450,6 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
assert len(variance.shape) == 2, "Expect 2D variance tensor."
self._mean = mean
self._variance = variance
self._scale = math_ops.sqrt(variance)
self._targets = targets
super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed)
@ -484,7 +459,7 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
@property
def dist(self):
return normal.Normal(loc=self._mean, scale=self._scale)
return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._variance))
@property
def params(self):
@ -502,7 +477,7 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
@property
def _fisher_mean_factor(self):
return 1. / self._scale
return 1. / math_ops.sqrt(self._variance)
@property
def _fisher_var(self):
@ -611,36 +586,13 @@ class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss,
index in [0, output_size).
seed: int or None. Default random seed when sampling.
"""
self._logits_components = []
self._targets_components = []
self.register_additional_minibatch(logits, targets=targets)
self._logits = logits
self._targets = targets
super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed)
def register_additional_minibatch(self, logits, targets=None):
"""Register an additiona minibatch's worth of parameters.
Args:
logits: Tensor of shape [batch_size, output_size]. Parameters for
underlying distribution.
targets: None or Tensor of shape [batch_size, output_size]. Each row must
be a one-hot vector.
"""
self._logits_components.append(logits)
self._targets_components.append(targets)
@property
def _logits(self):
return array_ops.concat(self._logits_components, axis=0)
@property
def input_minibatches(self):
return self._logits_components
@property
def targets(self):
if all(target is None for target in self._targets_components):
return None
return array_ops.concat(self._targets_components, axis=0)
return self._targets
@property
def dist(self):


@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import warnings
# pylint: disable=line-too-long
from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp
from tensorflow.contrib.kfac.python.ops import estimator as est
@ -50,6 +52,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
name="KFAC",
estimation_mode="gradients",
colocate_gradients_with_ops=True,
batch_size=None,
cov_devices=None,
inv_devices=None):
"""Initializes the KFAC optimizer with the given settings.
@ -91,12 +94,16 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
colocate_gradients_with_ops: Whether we should request gradients we
compute in the estimator be colocated with their respective ops.
(Default: True)
batch_size: The size of the mini-batch. Only needed when momentum_type
== 'qmodel' or when automatic adjustment is used. (Default: None)
cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
computations will be placed on these devices in a round-robin fashion.
Can be None, which means that no devices are specified.
Can be None, which means that no devices are specified. Only used
with the (soon-to-be-deprecated) "convenience" properties.
inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
computations will be placed on these devices in a round-robin fashion.
Can be None, which means that no devices are specified.
Can be None, which means that no devices are specified. Only used
with the (soon-to-be-deprecated) "convenience" properties.
Raises:
ValueError: If the momentum type is unsupported.
@ -110,6 +117,15 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
if variables is None:
variables = tf_variables.trainable_variables()
# Parameters to be passed to the Fisher estimator:
self._variables = variables
self._cov_ema_decay = cov_ema_decay
self._layers = layer_collection
self._estimation_mode = estimation_mode
self._colocate_gradients_with_ops = colocate_gradients_with_ops
self._cov_devices = cov_devices
self._inv_devices = inv_devices
# The parameters below are required only if damping needs to be adapted.
# These parameters can be set by calling
# set_damping_adaptation_params() explicitly.
@ -130,17 +146,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
self._q_model_change = None
self._update_damping_op = None
self._layers = layer_collection
self._fisher_est = est.FisherEstimator(
lambda: self.damping,
variables,
cov_ema_decay,
layer_collection,
estimation_mode=estimation_mode,
colocate_gradients_with_ops=colocate_gradients_with_ops,
cov_devices=cov_devices,
inv_devices=inv_devices)
momentum_type = momentum_type.lower()
legal_momentum_types = ["regular", "adam", "qmodel"]
@ -154,14 +159,21 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
raise ValueError("Momentum must be unspecified if using a momentum_type "
"other than 'regular' or 'adam'.")
# Extra parameters of the optimizer
self._momentum = momentum
self._momentum_type = momentum_type
self._norm_constraint = norm_constraint
self._batch_size = batch_size
# this is a bit of a hack
# TODO(duckworthd): Handle this in a better way (e.g. pass it in?)
self._batch_size = array_ops.shape(layer_collection.losses[0].inputs)[0]
self._losses = layer_collection.losses
with variable_scope.variable_scope(name):
self._fisher_est = est.FisherEstimator(
self._variables,
self._cov_ema_decay,
self.damping,
self._layers,
exps=(-1,),
estimation_mode=self._estimation_mode,
colocate_gradients_with_ops=self._colocate_gradients_with_ops)
super(KfacOptimizer, self).__init__(learning_rate, name=name)
@ -178,6 +190,10 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
style rule described in Section 6.5 of "Optimizing Neural Networks with
Kronecker-factored Approximate Curvature".
Note that this function creates Tensorflow variables which store a few
scalars and are accessed by the ops which update the damping (as part
of the training op returned by the minimize() method).
Args:
is_chief: `Boolean`, `True` if the worker is chief.
prev_train_batch: Training data used to minimize loss in the previous
@ -199,6 +215,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
"""
if self._adapt_damping:
raise ValueError("Damping adaptation parameters already set.")
with variable_scope.variable_scope(self.get_name()):
self._adapt_damping = True
self._is_chief = is_chief
@ -221,31 +238,37 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
@property
def cov_update_thunks(self):
return self._fisher_est.cov_update_thunks
self._maybe_make_and_save_everything()
return self._cov_update_thunks
@property
def cov_update_ops(self):
return self._fisher_est.cov_update_ops
self._maybe_make_and_save_everything()
return self._cov_update_ops
@property
def cov_update_op(self):
return self._fisher_est.cov_update_op
self._maybe_make_and_save_everything()
return self._cov_update_op
@property
def inv_update_thunks(self):
return self._fisher_est.inv_update_thunks
self._maybe_make_and_save_everything()
return self._inv_update_thunks
@property
def inv_update_ops(self):
return self._fisher_est.inv_update_ops
self._maybe_make_and_save_everything()
return self._inv_update_ops
@property
def inv_update_op(self):
return self._fisher_est.inv_update_op
self._maybe_make_and_save_everything()
return self._inv_update_op
@property
def variables(self):
return self._fisher_est.variables
return self._variables
@property
def damping(self):
@ -258,25 +281,162 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
def damping_adaptation_interval(self):
return self._damping_adaptation_interval
def _maybe_make_and_save_everything(self):
if not self._fisher_est.made_vars():
warnings.warn("These convenience properties will be depcrecated soon. "
"Please use explicit op/thunk creation methods instead "
"(e.g. make_ops_and_vars_round_robin, etc).",
DeprecationWarning)
(self._cov_update_ops, self._cov_update_op, self._inv_update_ops,
self._inv_update_op, self._cov_update_thunks,
self._inv_update_thunks) = self.make_ops_and_vars_round_robin(
cov_devices=self._cov_devices,
inv_devices=self._inv_devices)
def make_ops_and_vars(self):
"""Make ops and vars with no specific device placement.
See make_ops_and_vars_round_robin for details.
Returns:
cov_update_ops: List of ops that compute the cov updates. Corresponds
one-to-one with the list of factors given by the "factors" property.
cov_update_op: cov_update_ops grouped into a single op.
inv_update_ops: List of ops that compute the inv updates. Corresponds
one-to-one with the list of factors given by the "factors" property.
inv_update_op: inv_update_ops grouped into a single op.
"""
with variable_scope.variable_scope(self.get_name()):
return self._fisher_est.make_ops_and_vars()
def make_ops_and_vars_round_robin(self, cov_devices=None, inv_devices=None):
"""Make ops and vars with a round-robin device placement strategy.
For each factor, all of that factor's cov variables and their associated
update ops will be placed on a particular device. A new device is chosen
for each factor by cycling through list of devices in the cov_devices
argument. If cov_devices is None then no explicit device placement occurs.
An analogous strategy is followed for inverse update ops, with the list of
devices being given by the inv_devices argument.
Inverse variables on the other hand are not placed on any specific device
(they will just use the current device placement context, whatever
that happens to be). The idea is that the inverse variables belong where
they will be accessed most often, which is the device that actually applies
the preconditioner to the gradient. The user will be responsible for setting
the device context for this.
Args:
cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
computations will be placed on these devices in a round-robin fashion.
Can be None, which means that no devices are specified.
inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
computations will be placed on these devices in a round-robin fashion.
Can be None, which means that no devices are specified.
Returns:
cov_update_ops: List of ops that compute the cov updates. Corresponds
one-to-one with the list of factors given by the "factors" property.
cov_update_op: cov_update_ops grouped into a single op.
inv_update_ops: List of ops that compute the inv updates. Corresponds
one-to-one with the list of factors given by the "factors" property.
inv_update_op: inv_update_ops grouped into a single op.
cov_update_thunks: Thunks that make the ops in cov_update_ops.
inv_update_thunks: Thunks that make the ops in inv_update_ops.
"""
with variable_scope.variable_scope(self.get_name()):
return self._fisher_est.make_ops_and_vars_round_robin(
cov_devices=cov_devices, inv_devices=inv_devices)
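# Illustrative usage sketch, not part of this change: calling the method above
# on a hypothetical, already-constructed KfacOptimizer instance `opt`. The
# device strings are assumptions for the example.
(cov_update_ops, cov_update_op, inv_update_ops, inv_update_op,
 cov_update_thunks, inv_update_thunks) = opt.make_ops_and_vars_round_robin(
     cov_devices=["/gpu:0", "/gpu:1"], inv_devices=["/gpu:2", "/gpu:3"])
# With two cov devices, factor i's cov variables and cov update op land on
# cov_devices[i % 2]; its inversion op lands on inv_devices[i % 2]. The inverse
# variables themselves follow whatever device context is active at call time.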
def make_vars_and_create_op_thunks_round_robin(self,
cov_devices=None,
inv_devices=None):
"""Make vars and create op thunks w/ a round-robin device placement strat.
For each factor, all of that factor's cov variables and their associated
update ops will be placed on a particular device. A new device is chosen
for each factor by cycling through the list of devices in the cov_devices
argument. If cov_devices is None then no explicit device placement occurs.
An analogous strategy is followed for inverse update ops, with the list of
devices being given by the inv_devices argument.
Inverse variables, on the other hand, are not placed on any specific device
(they will just use the current device placement context, whatever
that happens to be). The idea is that the inverse variables belong where
they will be accessed most often, which is the device that actually applies
the preconditioner to the gradient. The user will be responsible for setting
the device context for this.
Args:
cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
computations will be placed on these devices in a round-robin fashion.
Can be None, which means that no devices are specified.
inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
computations will be placed on these devices in a round-robin fashion.
Can be None, which means that no devices are specified.
Returns:
cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
the list of factors given by the "factors" property.
inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
the list of factors given by the "factors" property.
"""
scope = self.get_name() + "/" + self._fisher_est.name
return self._fisher_est.make_vars_and_create_op_thunks_round_robin(
scope=scope, cov_devices=cov_devices, inv_devices=inv_devices)
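# Illustrative sketch, not part of this change: with the method above the
# variables are created immediately, while the update ops are only created
# when the returned thunks are called, so op creation can be deferred by the
# caller. `opt` is a hypothetical KfacOptimizer; the device lists are
# assumptions.
cov_update_thunks, inv_update_thunks = (
    opt.make_vars_and_create_op_thunks_round_robin(
        cov_devices=["/gpu:0", "/gpu:1"], inv_devices=["/gpu:0", "/gpu:1"]))
cov_update_ops = [thunk() for thunk in cov_update_thunks]  # ops made here
inv_update_ops = [thunk() for thunk in inv_update_thunks]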
def ops_and_vars_thunks(self):
"""Create thunks that make the ops and vars on demand.
This function returns 4 lists of thunks: cov_variable_thunks,
cov_update_thunks, inv_variable_thunks, and inv_update_thunks.
The length of each list is the number of factors and the i-th element of
each list corresponds to the i-th factor (given by the "factors" property).
Note that the execution of these thunks must happen in a certain
partial order. The i-th element of cov_variable_thunks must execute
before the i-th element of cov_update_thunks (and also the i-th element
of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks
must execute before the i-th element of inv_update_thunks.
TL;DR (oversimplified): Execute the thunks according to the order that
they are returned.
Returns:
cov_variable_thunks: A list of thunks that make the cov variables.
cov_update_thunks: A list of thunks that make the cov update ops.
inv_variable_thunks: A list of thunks that make the inv variables.
inv_update_thunks: A list of thunks that make the inv update ops.
"""
scope = self.get_name() + "/" + self._fisher_est.name
return self._fisher_est.ops_and_vars_thunks(scope=scope)
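# Illustrative sketch, not part of this change: one valid execution order for
# the four thunk lists described above (variable-making thunks run before the
# update-op thunks that read those variables). `opt` is a hypothetical,
# already-constructed KfacOptimizer instance.
(cov_variable_thunks, cov_update_thunks,
 inv_variable_thunks, inv_update_thunks) = opt.ops_and_vars_thunks()
for thunk in cov_variable_thunks:
  thunk()
for thunk in inv_variable_thunks:
  thunk()
cov_update_ops = [thunk() for thunk in cov_update_thunks]
inv_update_ops = [thunk() for thunk in inv_update_thunks]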
def minimize(self, *args, **kwargs):
kwargs["var_list"] = kwargs.get("var_list") or self.variables
if set(kwargs["var_list"]) != set(self.variables):
raise ValueError("var_list doesn't match with set of Fisher-estimating "
"variables.")
if self._adapt_damping and self._is_chief:
global_step = kwargs.get("global_step", None)
if not global_step:
raise KeyError("global_step needs to be passed to optimizer.minimize "
"if damping parameter is adapted.")
update_damping_op = self._update_damping(self._prev_train_batch,
global_step)
with ops.control_dependencies([update_damping_op]):
loss = args[0]
loss_assign_op = state_ops.assign(self._prev_loss, loss)
train_op = super(KfacOptimizer, self).minimize(*args, **kwargs)
return control_flow_ops.group(loss_assign_op, train_op)
else:
return super(KfacOptimizer, self).minimize(*args, **kwargs)
# Should this variable scope encompass everything below? Or will the super-
# class make another copy of the same name scope?
with variable_scope.variable_scope(self.get_name()):
kwargs["var_list"] = kwargs.get("var_list") or self.variables
if set(kwargs["var_list"]) != set(self.variables):
raise ValueError("var_list doesn't match with set of Fisher-estimating "
"variables.")
if self._adapt_damping and self._is_chief:
global_step = kwargs.get("global_step", None)
if not global_step:
raise KeyError("global_step needs to be passed to optimizer.minimize "
"if damping parameter is adapted.")
update_damping_op = self._update_damping(self._prev_train_batch,
global_step)
with ops.control_dependencies([update_damping_op]):
loss = args[0]
loss_assign_op = state_ops.assign(self._prev_loss, loss)
train_op = super(KfacOptimizer, self).minimize(*args, **kwargs)
return control_flow_ops.group(loss_assign_op, train_op)
else:
return super(KfacOptimizer, self).minimize(*args, **kwargs)
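# Illustrative sketch, not part of this change: when damping adaptation is
# enabled on the chief, minimize() must be given global_step or it raises
# KeyError, per the check above. `opt`, `loss`, and `global_step` are assumed
# to come from the surrounding training setup.
train_op = opt.minimize(loss, global_step=global_step)
# Without damping adaptation the call reduces to the superclass behavior:
# train_op = opt.minimize(loss)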
def compute_gradients(self, *args, **kwargs):
# args[1] could be our var_list
@@ -301,6 +461,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
Returns:
An `Operation` that applies the specified gradients.
"""
self._maybe_make_and_save_everything()
# In Python 3, grads_and_vars can be a zip() object which can only be
# iterated over once. By converting it to a list, we ensure that it can be
# iterated over more than once.
@@ -450,7 +612,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
= qmodel(alpha*precon_grad + mu*prev_update) - L(theta).
"""
cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._losses, variables)
cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._layers.losses,
variables)
# compute the matrix-vector products with the transposed Fisher factor
fft_precon_grads = cmvpc.multiply_fisher_factor_transpose(precon_grads)

View File

@@ -24,6 +24,7 @@ from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
@@ -482,5 +483,76 @@ def matmul_diag_sparse(A_diag, B, name=None):  # pylint: disable=invalid-name
a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1))
return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape)
class PartitionedTensor(object):
"""A Tensor partitioned across its 0-th dimension."""
def __init__(self, tensors):
"""Initializes PartitionedTensor.
Args:
tensors: List of Tensors. All Tensors must agree on shape (except the
batch dimension) and dtype.
Raises:
ValueError: If 'tensors' has length zero.
ValueError: If the contents of 'tensors' don't agree on shape or dtype.
"""
if not tensors:
raise ValueError("tensors must be a list of 1+ Tensors.")
dtype = tensors[0].dtype
if not all(tensor.dtype == dtype for tensor in tensors):
raise ValueError("all tensors must have dtype = %s." % dtype)
shape = tensors[0].shape[1:]
if not all(tensor.shape[1:] == shape for tensor in tensors):
raise ValueError("All tensors must have shape = %s (excluding batch "
"dimension)." % shape)
self.tensors = tensors
self._concats = {} # {device: Tensor}
@property
def shape(self):
feature_shape = self.tensors[0].shape[1:]
batch_size = sum([tensor.shape[0] for tensor in self.tensors],
tensor_shape.Dimension(0))
return tensor_shape.TensorShape([batch_size]).concatenate(feature_shape)
def get_shape(self):
return self.shape
@property
def dtype(self):
return self.tensors[0].dtype
def devices(self):
return set(tensor.device for tensor in self.tensors)
def __str__(self):
return "PartitionedTensor([%s, ...], dtype=%s, shape=%s)" % (
self.tensors[0].name, self.dtype.name, tuple(self.shape.as_list()))
def __hash__(self):
return hash(tuple(self.tensors))
def as_tensor(self, dtype=None, name=None, as_ref=False):
with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors):
assert not as_ref
assert dtype in [None, self.dtype]
result = array_ops.concat(self.tensors, axis=0)
# Cache 'result' if we haven't already cached a value for this device.
if result.device not in self._concats:
self._concats[result.device] = result
return self._concats[result.device]
ops.register_tensor_conversion_function(
PartitionedTensor,
lambda val, dtype, name, as_ref: val.as_tensor(dtype, name, as_ref))
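# Illustrative sketch, not part of this change: a PartitionedTensor built from
# two per-tower batches behaves like a single Tensor wherever one is expected,
# via the conversion function registered above. Uses only this module's
# existing imports (array_ops, ops).
tower0_acts = array_ops.ones([32, 128])
tower1_acts = array_ops.ones([32, 128])
acts = PartitionedTensor([tower0_acts, tower1_acts])
print(acts.shape)  # TensorShape([64, 128]): batch dims summed, features kept
dense = ops.convert_to_tensor(acts)  # concat along axis 0, cached per device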
# TODO(b/69623235): Add a function for finding tensors that share gradients
# to eliminate redundant fisher factor computations.