- FisherEstimator now supports computing products with arbitrary matrix powers of the approximate Fisher
- Added multi-tower support to multi/RNN fully connected layers - All op creation is now done inside functions that explicitly create ops, thus allowing fine control of their placement. One result of this is that we no longer need any colocation statements (and these have been removed) - Multi-tower computations are now handled using ParitionedTensor class, which appears to be a single tensor to the FisherFactors but actually contains a list of tensors. - To achieve the above damping values are passed around as special functions that are packaged along with "ids" that can be used to uniquely identify the computation they perform. Topohash might provide a better solution for this in the future. - Variable creation in the factors is now done via special methods so we can have fine control over where these are placed - FisherEstimator now has special functions to create ops and variables using different placement strategies (currently: no strategy, round-robin, and as thunks). By default this will use the round-robin strategy and manufacture the usual convenience properties ("inv_update_ops", etc). This default behavior is to preserve backwards compatibility but in the future we should deprecate this and require the user to ask for an explicit strategy. - LossFunctions no longer make any ops in their constructors. The only make ops when evaluated. LayerCollection maintains a list of tensors/ops which we can colocate LossFunction computations with (typically their inputs) - LossFunctions no longer support multi-tower/mini-batches directly. Instead LayerCollection maintains a list of these objects, one for each tower. This solution is better since now the loss function related computations can take place exclusively on the corresponding tower. - All loss functions now support multiple towers/minibatches (via LayerCollection). - tf.gradients is passed list of loss function values instead of their sum, which will prevent extraneous gradient ops being placed on arbitrary devices. Hopefully with this change and the above one for loss functions all ops associated with gradient computations (for computing stats) will occur completely on the device that defines that part of the graph. e.g. this will do the right thing for multiple towers - I've also made sure that sensible colocation occurs for the extra ops needed by the curvature_propagation and exact estimation modes. - Variables and ops made by FisherEstimator are now placed inside of name scopes (based on the name given to FisherEstimator) - Restored old variable use count tracker implementation, thus fixing the issue with how generic registrations were handled by check_registration(). - Restored interface to FisherEstimator (which was changed in the previous CL). - Fixed bug in LazyKFacOptimizer: optional/named arguments weren't being passed in properly - Lots of other minor refactors/improvements PiperOrigin-RevId: 188310846
This commit is contained in:
parent
e52f916b87
commit
4ac1fee7f1
@ -36,6 +36,7 @@ py_test(
|
||||
srcs = ["fisher_factors_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/kfac/python/ops:fisher_blocks",
|
||||
"//tensorflow/contrib/kfac/python/ops:fisher_factors",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
|
@ -90,59 +90,75 @@ class EstimatorTest(test.TestCase):
|
||||
def testEstimatorInitManualRegistration(self):
|
||||
with self._graph.as_default():
|
||||
# We should be able to build an estimator for only the registered vars.
|
||||
estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1,
|
||||
estimator.FisherEstimator([self.weights], 0.1, 0.2,
|
||||
self.layer_collection)
|
||||
|
||||
# Check that we throw an error if we try to build an estimator for vars
|
||||
# that were not manually registered.
|
||||
with self.assertRaises(ValueError):
|
||||
estimator.FisherEstimator(lambda: 0.2, [self.weights, self.bias], 0.1,
|
||||
estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2,
|
||||
self.layer_collection)
|
||||
|
||||
# Check that we throw an error if we don't include registered variables,
|
||||
# i.e. self.weights
|
||||
with self.assertRaises(ValueError):
|
||||
estimator.FisherEstimator(lambda: 0.2, [], 0.1, self.layer_collection)
|
||||
estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection)
|
||||
|
||||
@test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42)
|
||||
def testVariableWrongNumberOfUses(self, mock_uses):
|
||||
with self.assertRaises(ValueError):
|
||||
estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1,
|
||||
estimator.FisherEstimator([self.weights], 0.1, 0.2,
|
||||
self.layer_collection)
|
||||
|
||||
def testInvalidEstimationMode(self):
|
||||
with self.assertRaises(ValueError):
|
||||
estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1,
|
||||
self.layer_collection, "not_a_real_mode")
|
||||
estimator.FisherEstimator([self.weights], 0.1, 0.2,
|
||||
self.layer_collection,
|
||||
estimation_mode="not_a_real_mode")
|
||||
|
||||
def testModeListCorrect(self):
|
||||
def testGradientsModeBuild(self):
|
||||
with self._graph.as_default():
|
||||
est = estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1,
|
||||
self.layer_collection)
|
||||
self.assertItemsEqual(_ALL_ESTIMATION_MODES, est._gradient_fns.keys())
|
||||
estimator.FisherEstimator([self.weights], 0.1, 0.2,
|
||||
self.layer_collection,
|
||||
estimation_mode="gradients")
|
||||
|
||||
def testAllModesBuild(self):
|
||||
for mode in _ALL_ESTIMATION_MODES:
|
||||
with self._graph.as_default():
|
||||
estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1,
|
||||
self.layer_collection, mode)
|
||||
def testEmpiricalModeBuild(self):
|
||||
with self._graph.as_default():
|
||||
estimator.FisherEstimator([self.weights], 0.1, 0.2,
|
||||
self.layer_collection,
|
||||
estimation_mode="empirical")
|
||||
|
||||
def testCurvaturePropModeBuild(self):
|
||||
with self._graph.as_default():
|
||||
estimator.FisherEstimator([self.weights], 0.1, 0.2,
|
||||
self.layer_collection,
|
||||
estimation_mode="curvature_prop")
|
||||
|
||||
def testExactModeBuild(self):
|
||||
with self._graph.as_default():
|
||||
estimator.FisherEstimator([self.weights], 0.1, 0.2,
|
||||
self.layer_collection,
|
||||
estimation_mode="exact")
|
||||
|
||||
def test_cov_update_thunks(self):
|
||||
"""Ensures covariance update ops run once per global_step."""
|
||||
with self._graph.as_default(), self.test_session() as sess:
|
||||
fisher_estimator = estimator.FisherEstimator(
|
||||
damping_fn=lambda: 0.2,
|
||||
variables=[self.weights],
|
||||
layer_collection=self.layer_collection,
|
||||
damping=0.2,
|
||||
cov_ema_decay=0.0)
|
||||
|
||||
# Construct an op that executes one covariance update per step.
|
||||
global_step = training_util.get_or_create_global_step()
|
||||
(cov_variable_thunks, cov_update_op_thunks,
|
||||
_, _) = fisher_estimator.create_ops_and_vars_thunks()
|
||||
for thunk in cov_variable_thunks:
|
||||
thunk()
|
||||
cov_matrices = [
|
||||
fisher_factor.get_cov()
|
||||
for fisher_factor in self.layer_collection.get_factors()
|
||||
]
|
||||
cov_update_op_thunks = fisher_estimator.cov_update_thunks
|
||||
cov_update_op = control_flow_ops.case(
|
||||
[(math_ops.equal(global_step, i), thunk)
|
||||
for i, thunk in enumerate(cov_update_op_thunks)])
|
||||
@ -178,19 +194,24 @@ class EstimatorTest(test.TestCase):
|
||||
"""Ensures inverse update ops run once per global_step."""
|
||||
with self._graph.as_default(), self.test_session() as sess:
|
||||
fisher_estimator = estimator.FisherEstimator(
|
||||
damping_fn=lambda: 0.2,
|
||||
variables=[self.weights],
|
||||
layer_collection=self.layer_collection,
|
||||
damping=0.2,
|
||||
cov_ema_decay=0.0)
|
||||
|
||||
# Construct op that updates one inverse per global step.
|
||||
global_step = training_util.get_or_create_global_step()
|
||||
(cov_variable_thunks, _, inv_variable_thunks,
|
||||
inv_update_op_thunks) = fisher_estimator.create_ops_and_vars_thunks()
|
||||
for thunk in cov_variable_thunks:
|
||||
thunk()
|
||||
for thunk in inv_variable_thunks:
|
||||
thunk()
|
||||
inv_matrices = [
|
||||
matrix
|
||||
for fisher_factor in self.layer_collection.get_factors()
|
||||
for matrix in fisher_factor._inverses_by_damping.values()
|
||||
for matrix in fisher_factor._matpower_by_exp_and_damping.values()
|
||||
]
|
||||
inv_update_op_thunks = fisher_estimator.inv_update_thunks
|
||||
inv_update_op = control_flow_ops.case(
|
||||
[(math_ops.equal(global_step, i), thunk)
|
||||
for i, thunk in enumerate(inv_update_op_thunks)])
|
||||
|
@ -94,6 +94,9 @@ class FullFBTest(test.TestCase):
|
||||
block.register_additional_minibatch(32)
|
||||
grads = (params[0]**2, math_ops.sqrt(params[1]))
|
||||
block.instantiate_factors((grads,), 0.5)
|
||||
block._factor.instantiate_cov_variables()
|
||||
block.register_inverse()
|
||||
block._factor.instantiate_inv_variables()
|
||||
|
||||
# Make sure our inverse is something other than the identity.
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
@ -112,6 +115,9 @@ class FullFBTest(test.TestCase):
|
||||
block.register_additional_minibatch(32)
|
||||
grads = params**2
|
||||
block.instantiate_factors((grads,), 0.5)
|
||||
block._factor.instantiate_cov_variables()
|
||||
block.register_inverse()
|
||||
block._factor.instantiate_inv_variables()
|
||||
|
||||
# Make sure our inverse is something other than the identity.
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
@ -131,6 +137,9 @@ class FullFBTest(test.TestCase):
|
||||
grads = (array_ops.constant([2., 3.]), array_ops.constant(4.))
|
||||
damping = 0.5
|
||||
block.instantiate_factors((grads,), damping)
|
||||
block._factor.instantiate_cov_variables()
|
||||
block.register_inverse()
|
||||
block._factor.instantiate_inv_variables()
|
||||
|
||||
# Make sure our inverse is something other than the identity.
|
||||
sess.run(state_ops.assign(block._factor._cov, _make_psd(3)))
|
||||
@ -185,6 +194,7 @@ class NaiveDiagonalFBTest(test.TestCase):
|
||||
block.register_additional_minibatch(32)
|
||||
grads = (params[0]**2, math_ops.sqrt(params[1]))
|
||||
block.instantiate_factors((grads,), 0.5)
|
||||
block._factor.instantiate_cov_variables()
|
||||
|
||||
# Make sure our inverse is something other than the identity.
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
@ -203,6 +213,7 @@ class NaiveDiagonalFBTest(test.TestCase):
|
||||
block.register_additional_minibatch(32)
|
||||
grads = params**2
|
||||
block.instantiate_factors((grads,), 0.5)
|
||||
block._factor.instantiate_cov_variables()
|
||||
|
||||
# Make sure our inverse is something other than the identity.
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
@ -221,6 +232,7 @@ class NaiveDiagonalFBTest(test.TestCase):
|
||||
grads = (params[0]**2, math_ops.sqrt(params[1]))
|
||||
damping = 0.5
|
||||
block.instantiate_factors((grads,), damping)
|
||||
block._factor.instantiate_cov_variables()
|
||||
|
||||
cov = array_ops.reshape(array_ops.constant([2., 3., 4.]), [-1, 1])
|
||||
sess.run(state_ops.assign(block._factor._cov, cov))
|
||||
@ -367,6 +379,7 @@ class FullyConnectedDiagonalFBTest(test.TestCase):
|
||||
block.register_additional_minibatch(i, o)
|
||||
|
||||
block.instantiate_factors((output_grads,), damping=0.0)
|
||||
block._factor.instantiate_cov_variables()
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
sess.run(block._factor.make_covariance_update_op(0.0))
|
||||
@ -394,7 +407,7 @@ class EmbeddingKFACFBTest(test.TestCase):
|
||||
# Instantiate factor's variables. Ensure it doesn't fail.
|
||||
grads = outputs**2.
|
||||
damping = array_ops.constant(0.)
|
||||
block.instantiate_factors(([grads],), damping)
|
||||
block.instantiate_factors(((grads,),), damping)
|
||||
|
||||
def testMultiplyInverse(self):
|
||||
with ops.Graph().as_default(), self.test_session() as sess:
|
||||
@ -412,7 +425,12 @@ class EmbeddingKFACFBTest(test.TestCase):
|
||||
# Instantiate factor's variables. Ensure it doesn't fail.
|
||||
grads = outputs**2.
|
||||
damping = array_ops.constant(0.)
|
||||
block.instantiate_factors(([grads],), damping)
|
||||
block.instantiate_factors(((grads,),), damping)
|
||||
block._input_factor.instantiate_cov_variables()
|
||||
block._output_factor.instantiate_cov_variables()
|
||||
block.register_inverse()
|
||||
block._input_factor.instantiate_inv_variables()
|
||||
block._output_factor.instantiate_inv_variables()
|
||||
|
||||
# Create a sparse update.
|
||||
indices = array_ops.constant([1, 3, 4])
|
||||
@ -456,7 +474,7 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
|
||||
grads = outputs**2
|
||||
block.instantiate_factors(([grads],), 0.5)
|
||||
block.instantiate_factors(((grads,),), 0.5)
|
||||
|
||||
def testInstantiateFactorsNoBias(self):
|
||||
with ops.Graph().as_default():
|
||||
@ -467,7 +485,7 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
|
||||
grads = outputs**2
|
||||
block.instantiate_factors(([grads],), 0.5)
|
||||
block.instantiate_factors(((grads,),), 0.5)
|
||||
|
||||
def testMultiplyInverseTuple(self):
|
||||
with ops.Graph().as_default(), self.test_session() as sess:
|
||||
@ -477,7 +495,13 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
|
||||
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
grads = outputs**2
|
||||
block.instantiate_factors(([grads],), 0.5)
|
||||
block.instantiate_factors(((grads,),), 0.5)
|
||||
|
||||
block._input_factor.instantiate_cov_variables()
|
||||
block._output_factor.instantiate_cov_variables()
|
||||
block.register_inverse()
|
||||
block._input_factor.instantiate_inv_variables()
|
||||
block._output_factor.instantiate_inv_variables()
|
||||
|
||||
# Make sure our inverse is something other than the identity.
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
@ -503,7 +527,12 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
|
||||
block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
grads = outputs**2
|
||||
block.instantiate_factors(([grads],), 0.5)
|
||||
block.instantiate_factors(((grads,),), 0.5)
|
||||
block._input_factor.instantiate_cov_variables()
|
||||
block._output_factor.instantiate_cov_variables()
|
||||
block.register_inverse()
|
||||
block._input_factor.instantiate_inv_variables()
|
||||
block._output_factor.instantiate_inv_variables()
|
||||
|
||||
# Make sure our inverse is something other than the identity.
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
@ -527,10 +556,17 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
grads = outputs**2
|
||||
damping = 0. # This test is only valid without damping.
|
||||
block.instantiate_factors(([grads],), damping)
|
||||
block.instantiate_factors(((grads,),), damping)
|
||||
block._input_factor.instantiate_cov_variables()
|
||||
block._output_factor.instantiate_cov_variables()
|
||||
|
||||
sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3)))
|
||||
sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
|
||||
|
||||
block.register_inverse()
|
||||
block._input_factor.instantiate_inv_variables()
|
||||
block._output_factor.instantiate_inv_variables()
|
||||
|
||||
sess.run(block._input_factor.make_inverse_update_ops())
|
||||
sess.run(block._output_factor.make_inverse_update_ops())
|
||||
|
||||
@ -718,6 +754,7 @@ class ConvDiagonalFBTest(test.TestCase):
|
||||
block.register_additional_minibatch(i, o)
|
||||
|
||||
block.instantiate_factors((output_grads,), damping=0.0)
|
||||
block._factor.instantiate_cov_variables()
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
sess.run(block._factor.make_covariance_update_op(0.0))
|
||||
@ -759,7 +796,12 @@ class ConvKFCBasicFBTest(test.TestCase):
|
||||
'SAME')
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
grads = outputs**2
|
||||
block.instantiate_factors(([grads],), 0.5)
|
||||
block.instantiate_factors(((grads,),), 0.5)
|
||||
block._input_factor.instantiate_cov_variables()
|
||||
block._output_factor.instantiate_cov_variables()
|
||||
block.register_inverse()
|
||||
block._input_factor.instantiate_inv_variables()
|
||||
block._output_factor.instantiate_inv_variables()
|
||||
|
||||
# Make sure our inverse is something other than the identity.
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
@ -786,7 +828,12 @@ class ConvKFCBasicFBTest(test.TestCase):
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
self.assertFalse(block._has_bias)
|
||||
grads = outputs**2
|
||||
block.instantiate_factors(([grads],), 0.5)
|
||||
block.instantiate_factors(((grads,),), 0.5)
|
||||
block._input_factor.instantiate_cov_variables()
|
||||
block._output_factor.instantiate_cov_variables()
|
||||
block.register_inverse()
|
||||
block._input_factor.instantiate_inv_variables()
|
||||
block._output_factor.instantiate_inv_variables()
|
||||
|
||||
# Make sure our inverse is something other than the identity.
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
@ -809,7 +856,12 @@ class ConvKFCBasicFBTest(test.TestCase):
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
self.assertTrue(block._has_bias)
|
||||
grads = outputs**2
|
||||
block.instantiate_factors(([grads],), 0.5)
|
||||
block.instantiate_factors(((grads,),), 0.5)
|
||||
block._input_factor.instantiate_cov_variables()
|
||||
block._output_factor.instantiate_cov_variables()
|
||||
block.register_inverse()
|
||||
block._input_factor.instantiate_inv_variables()
|
||||
block._output_factor.instantiate_inv_variables()
|
||||
|
||||
# Make sure our inverse is something other than the identity.
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
@ -832,7 +884,12 @@ class ConvKFCBasicFBTest(test.TestCase):
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
grads = outputs**2
|
||||
damping = 0. # This test is only valid without damping.
|
||||
block.instantiate_factors(([grads],), damping)
|
||||
block.instantiate_factors(((grads,),), damping)
|
||||
block._input_factor.instantiate_cov_variables()
|
||||
block._output_factor.instantiate_cov_variables()
|
||||
block.register_inverse()
|
||||
block._input_factor.instantiate_inv_variables()
|
||||
block._output_factor.instantiate_inv_variables()
|
||||
|
||||
sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8)))
|
||||
sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
|
||||
@ -857,9 +914,9 @@ class FullyConnectedSeriesFBTest(test.TestCase):
|
||||
random_seed.set_random_seed(200)
|
||||
inputs = array_ops.constant([1., 2.])
|
||||
outputs = array_ops.constant([3., 4.])
|
||||
block = fb.FullyConnectedSeriesFB(
|
||||
lc.LayerCollection(), inputs=[inputs], outputs=[outputs])
|
||||
self.assertAllEqual([outputs], block.tensors_to_compute_grads())
|
||||
block = fb.FullyConnectedSeriesFB(lc.LayerCollection())
|
||||
block.register_additional_minibatch([inputs], [outputs])
|
||||
self.assertAllEqual([[outputs]], block.tensors_to_compute_grads())
|
||||
|
||||
def testInstantiateFactorsHasBias(self):
|
||||
with ops.Graph().as_default():
|
||||
@ -868,11 +925,10 @@ class FullyConnectedSeriesFBTest(test.TestCase):
|
||||
outputs = array_ops.constant([[3., 4.], [5., 6.]])
|
||||
block = fb.FullyConnectedSeriesFB(
|
||||
lc.LayerCollection(),
|
||||
inputs=[inputs],
|
||||
outputs=[outputs],
|
||||
has_bias=True)
|
||||
block.register_additional_minibatch([inputs], [outputs])
|
||||
grads = outputs**2
|
||||
block.instantiate_factors(((grads,),), 0.5)
|
||||
block.instantiate_factors((((grads,),),), 0.5)
|
||||
|
||||
def testInstantiateFactorsNoBias(self):
|
||||
with ops.Graph().as_default():
|
||||
@ -881,11 +937,10 @@ class FullyConnectedSeriesFBTest(test.TestCase):
|
||||
outputs = array_ops.constant([[3., 4.], [5., 6.]])
|
||||
block = fb.FullyConnectedSeriesFB(
|
||||
lc.LayerCollection(),
|
||||
inputs=[inputs],
|
||||
outputs=[outputs],
|
||||
has_bias=False)
|
||||
block.register_additional_minibatch([inputs], [outputs])
|
||||
grads = outputs**2
|
||||
block.instantiate_factors(((grads,),), 0.5)
|
||||
block.instantiate_factors((((grads,),),), 0.5)
|
||||
|
||||
|
||||
def as_tensors(tensor_or_tuple):
|
||||
|
@ -21,8 +21,8 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
import numpy.random as npr
|
||||
|
||||
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
|
||||
from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops as tf_ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
@ -33,32 +33,8 @@ from tensorflow.python.ops import variables as tf_variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class MaybeColocateTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._colocate_cov_ops_with_inputs = ff.COLOCATE_COV_OPS_WITH_INPUTS
|
||||
|
||||
def tearDown(self):
|
||||
ff.set_global_constants(
|
||||
colocate_cov_ops_with_inputs=self._colocate_cov_ops_with_inputs)
|
||||
|
||||
def testFalse(self):
|
||||
ff.set_global_constants(colocate_cov_ops_with_inputs=False)
|
||||
with tf_ops.Graph().as_default():
|
||||
a = constant_op.constant([2.0], name='a')
|
||||
with ff.maybe_colocate_with(a):
|
||||
b = constant_op.constant(3.0, name='b')
|
||||
self.assertEqual([b'loc:@a'], a.op.colocation_groups())
|
||||
self.assertEqual([b'loc:@b'], b.op.colocation_groups())
|
||||
|
||||
def testTrue(self):
|
||||
ff.set_global_constants(colocate_cov_ops_with_inputs=True)
|
||||
with tf_ops.Graph().as_default():
|
||||
a = constant_op.constant([2.0], name='a')
|
||||
with ff.maybe_colocate_with(a):
|
||||
b = constant_op.constant(3.0, name='b')
|
||||
self.assertEqual([b'loc:@a'], a.op.colocation_groups())
|
||||
self.assertEqual([b'loc:@a'], b.op.colocation_groups())
|
||||
def make_damping_func(damping):
|
||||
return fb._package_func(lambda: damping, damping)
|
||||
|
||||
|
||||
class FisherFactorTestingDummy(ff.FisherFactor):
|
||||
@ -98,10 +74,13 @@ class FisherFactorTestingDummy(ff.FisherFactor):
|
||||
def right_multiply(self, x, damping):
|
||||
return NotImplementedError
|
||||
|
||||
def left_multiply_inverse(self, x, damping):
|
||||
def left_multiply_matpower(self, x, exp, damping):
|
||||
return NotImplementedError
|
||||
|
||||
def right_multiply_inverse(self, x, damping):
|
||||
def right_multiply_matpower(self, x, exp, damping):
|
||||
return NotImplementedError
|
||||
|
||||
def instantiate_inv_variables(self):
|
||||
return NotImplementedError
|
||||
|
||||
|
||||
@ -246,21 +225,24 @@ class InverseProvidingFactorTest(test.TestCase):
|
||||
factor = InverseProvidingFactorTestingDummy(shape)
|
||||
factor_var_scope = 'dummy/a_b_c'
|
||||
|
||||
dampings = 0.1, 1e-1, 0.00001, 1e-5
|
||||
damping_funcs = [make_damping_func(0.1),
|
||||
make_damping_func(0.1),
|
||||
make_damping_func(1e-5),
|
||||
make_damping_func(1e-5)]
|
||||
for damping_func in damping_funcs:
|
||||
factor.register_inverse(damping_func)
|
||||
|
||||
for damping in dampings:
|
||||
factor.register_damped_inverse(damping)
|
||||
factor.instantiate_inv_variables()
|
||||
|
||||
self.assertEqual(set(dampings), set(factor._inverses_by_damping.keys()))
|
||||
inv = factor._inverses_by_damping[dampings[0]]
|
||||
self.assertEqual(inv, factor._inverses_by_damping[dampings[1]])
|
||||
self.assertNotEqual(inv, factor._inverses_by_damping[dampings[2]])
|
||||
self.assertEqual(factor._inverses_by_damping[dampings[2]],
|
||||
factor._inverses_by_damping[dampings[3]])
|
||||
inv = factor.get_inverse(damping_funcs[0])
|
||||
self.assertEqual(inv, factor.get_inverse(damping_funcs[1]))
|
||||
self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]))
|
||||
self.assertEqual(factor.get_inverse(damping_funcs[2]),
|
||||
factor.get_inverse(damping_funcs[3]))
|
||||
factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
|
||||
factor_var_scope)
|
||||
self.assertListEqual([inv, factor._inverses_by_damping[dampings[2]]],
|
||||
factor_vars)
|
||||
self.assertEqual(set([inv, factor.get_inverse(damping_funcs[2])]),
|
||||
set(factor_vars))
|
||||
self.assertEqual(shape, inv.get_shape())
|
||||
|
||||
def testRegisterMatpower(self):
|
||||
@ -270,17 +252,22 @@ class InverseProvidingFactorTest(test.TestCase):
|
||||
factor = InverseProvidingFactorTestingDummy(shape)
|
||||
factor_var_scope = 'dummy/a_b_c'
|
||||
|
||||
factor.register_matpower(1, 0.5)
|
||||
factor.register_matpower(2, 0.5)
|
||||
# TODO(b/74201126): Change to using the same func for both once
|
||||
# Topohash is in place.
|
||||
damping_func_1 = make_damping_func(0.5)
|
||||
damping_func_2 = make_damping_func(0.5)
|
||||
|
||||
factor.register_matpower(-0.5, damping_func_1)
|
||||
factor.register_matpower(2, damping_func_2)
|
||||
|
||||
factor.instantiate_inv_variables()
|
||||
|
||||
self.assertEqual(
|
||||
set([(1, 0.5), (2, 0.5)]),
|
||||
set(factor._matpower_by_exp_and_damping.keys()))
|
||||
factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
|
||||
factor_var_scope)
|
||||
matpower1 = factor.get_matpower(1, 0.5)
|
||||
matpower2 = factor.get_matpower(2, 0.5)
|
||||
self.assertListEqual([matpower1, matpower2], factor_vars)
|
||||
matpower1 = factor.get_matpower(-0.5, damping_func_1)
|
||||
matpower2 = factor.get_matpower(2, damping_func_2)
|
||||
|
||||
self.assertEqual(set([matpower1, matpower2]), set(factor_vars))
|
||||
|
||||
self.assertEqual(shape, matpower1.get_shape())
|
||||
self.assertEqual(shape, matpower2.get_shape())
|
||||
@ -299,17 +286,24 @@ class InverseProvidingFactorTest(test.TestCase):
|
||||
factor = InverseProvidingFactorTestingDummy(cov.shape)
|
||||
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
|
||||
|
||||
damping_funcs = []
|
||||
for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1):
|
||||
factor.register_damped_inverse(1. / i)
|
||||
damping_funcs.append(make_damping_func(1./i))
|
||||
|
||||
for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
|
||||
factor.register_inverse(damping_funcs[i])
|
||||
|
||||
factor.instantiate_inv_variables()
|
||||
ops = factor.make_inverse_update_ops()
|
||||
self.assertEqual(1, len(ops))
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
new_invs = []
|
||||
sess.run(ops)
|
||||
for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1):
|
||||
for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
|
||||
# The inverse op will assign the damped inverse of cov to the inv var.
|
||||
new_invs.append(sess.run(factor._inverses_by_damping[1. / i]))
|
||||
new_invs.append(sess.run(factor.get_inverse(damping_funcs[i])))
|
||||
|
||||
# We want to see that the new invs are all different from each other.
|
||||
for i in range(len(new_invs)):
|
||||
for j in range(i + 1, len(new_invs)):
|
||||
@ -324,14 +318,16 @@ class InverseProvidingFactorTest(test.TestCase):
|
||||
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
|
||||
exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power
|
||||
damping = 0.5
|
||||
damping_func = make_damping_func(damping)
|
||||
|
||||
factor.register_matpower(exp, damping)
|
||||
factor.register_matpower(exp, damping_func)
|
||||
factor.instantiate_inv_variables()
|
||||
ops = factor.make_inverse_update_ops()
|
||||
self.assertEqual(1, len(ops))
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
sess.run(ops[0])
|
||||
matpower = sess.run(factor._matpower_by_exp_and_damping[(exp, damping)])
|
||||
matpower = sess.run(factor.get_matpower(exp, damping_func))
|
||||
matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp)
|
||||
self.assertAllClose(matpower, matpower_np)
|
||||
|
||||
@ -342,18 +338,21 @@ class InverseProvidingFactorTest(test.TestCase):
|
||||
factor = InverseProvidingFactorTestingDummy(cov.shape)
|
||||
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
|
||||
|
||||
factor.register_damped_inverse(0)
|
||||
damping_func = make_damping_func(0)
|
||||
|
||||
factor.register_inverse(damping_func)
|
||||
factor.instantiate_inv_variables()
|
||||
ops = factor.make_inverse_update_ops()
|
||||
self.assertEqual(1, len(ops))
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
# The inverse op will assign the damped inverse of cov to the inv var.
|
||||
old_inv = sess.run(factor._inverses_by_damping[0])
|
||||
old_inv = sess.run(factor.get_inverse(damping_func))
|
||||
self.assertAllClose(
|
||||
sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv)
|
||||
|
||||
sess.run(ops)
|
||||
new_inv = sess.run(factor._inverses_by_damping[0])
|
||||
new_inv = sess.run(factor.get_inverse(damping_func))
|
||||
self.assertAllClose(new_inv, np.linalg.inv(cov))
|
||||
|
||||
|
||||
@ -364,6 +363,7 @@ class FullFactorTest(test.TestCase):
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.ones((2, 3), name='a/b/c')
|
||||
factor = ff.FullFactor((tensor,), 32)
|
||||
factor.instantiate_cov_variables()
|
||||
self.assertEqual([6, 6], factor.get_cov().get_shape().as_list())
|
||||
|
||||
def testFullFactorInitFloat64(self):
|
||||
@ -372,6 +372,7 @@ class FullFactorTest(test.TestCase):
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
|
||||
factor = ff.FullFactor((tensor,), 32)
|
||||
factor.instantiate_cov_variables()
|
||||
cov = factor.get_cov()
|
||||
self.assertEqual(cov.dtype, dtype)
|
||||
self.assertEqual([6, 6], cov.get_shape().as_list())
|
||||
@ -381,6 +382,7 @@ class FullFactorTest(test.TestCase):
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.constant([1., 2.], name='a/b/c')
|
||||
factor = ff.FullFactor((tensor,), 2)
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
new_cov = sess.run(factor.make_covariance_update_op(.5))
|
||||
@ -394,6 +396,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.ones((2, 3), name='a/b/c')
|
||||
factor = ff.NaiveDiagonalFactor((tensor,), 32)
|
||||
factor.instantiate_cov_variables()
|
||||
self.assertEqual([6, 1], factor.get_cov_var().get_shape().as_list())
|
||||
|
||||
def testNaiveDiagonalFactorInitFloat64(self):
|
||||
@ -402,6 +405,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
|
||||
factor = ff.NaiveDiagonalFactor((tensor,), 32)
|
||||
factor.instantiate_cov_variables()
|
||||
cov = factor.get_cov_var()
|
||||
self.assertEqual(cov.dtype, dtype)
|
||||
self.assertEqual([6, 1], cov.get_shape().as_list())
|
||||
@ -411,6 +415,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.constant([1., 2.], name='a/b/c')
|
||||
factor = ff.NaiveDiagonalFactor((tensor,), 2)
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
new_cov = sess.run(factor.make_covariance_update_op(.5))
|
||||
@ -423,7 +428,8 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase):
|
||||
with tf_ops.Graph().as_default():
|
||||
input_ids = array_ops.constant([[0], [1], [4]])
|
||||
vocab_size = 5
|
||||
factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
|
||||
factor = ff.EmbeddingInputKroneckerFactor(input_ids, vocab_size)
|
||||
factor.instantiate_cov_variables()
|
||||
cov = factor.get_cov_var()
|
||||
self.assertEqual(cov.shape.as_list(), [vocab_size])
|
||||
|
||||
@ -431,7 +437,8 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase):
|
||||
with tf_ops.Graph().as_default():
|
||||
input_ids = array_ops.constant([[0], [1], [4]])
|
||||
vocab_size = 5
|
||||
factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
|
||||
factor = ff.EmbeddingInputKroneckerFactor(input_ids, vocab_size)
|
||||
factor.instantiate_cov_variables()
|
||||
cov_update_op = factor.make_covariance_update_op(0.0)
|
||||
|
||||
with self.test_session() as sess:
|
||||
@ -450,6 +457,7 @@ class FullyConnectedKroneckerFactorTest(test.TestCase):
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
|
||||
factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=has_bias)
|
||||
factor.instantiate_cov_variables()
|
||||
cov = factor.get_cov()
|
||||
self.assertEqual(cov.dtype, dtype)
|
||||
self.assertEqual(final_shape, cov.get_shape().as_list())
|
||||
@ -467,6 +475,7 @@ class FullyConnectedKroneckerFactorTest(test.TestCase):
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
|
||||
factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=True)
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
new_cov = sess.run(factor.make_covariance_update_op(.5))
|
||||
@ -477,6 +486,7 @@ class FullyConnectedKroneckerFactorTest(test.TestCase):
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
|
||||
factor = ff.FullyConnectedKroneckerFactor((tensor,))
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
new_cov = sess.run(factor.make_covariance_update_op(.5))
|
||||
@ -491,6 +501,7 @@ class ConvInputKroneckerFactorTest(test.TestCase):
|
||||
tensor = array_ops.ones((2, 3), name='a/b/c')
|
||||
factor = ff.ConvInputKroneckerFactor(
|
||||
tensor, (1, 2, 3, 4), 3, 2, has_bias=False)
|
||||
factor.instantiate_cov_variables()
|
||||
self.assertEqual([1 * 2 * 3, 1 * 2 * 3],
|
||||
factor.get_cov().get_shape().as_list())
|
||||
|
||||
@ -500,6 +511,7 @@ class ConvInputKroneckerFactorTest(test.TestCase):
|
||||
tensor = array_ops.ones((2, 3), name='a/b/c')
|
||||
factor = ff.ConvInputKroneckerFactor(
|
||||
tensor, (1, 2, 3, 4), 3, 2, has_bias=True)
|
||||
factor.instantiate_cov_variables()
|
||||
self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
|
||||
factor.get_cov().get_shape().as_list())
|
||||
|
||||
@ -510,6 +522,7 @@ class ConvInputKroneckerFactorTest(test.TestCase):
|
||||
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
|
||||
factor = ff.ConvInputKroneckerFactor(
|
||||
tensor, (1, 2, 3, 4), 3, 2, has_bias=True)
|
||||
factor.instantiate_cov_variables()
|
||||
cov = factor.get_cov()
|
||||
self.assertEqual(cov.dtype, dtype)
|
||||
self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
|
||||
@ -522,6 +535,7 @@ class ConvInputKroneckerFactorTest(test.TestCase):
|
||||
np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32)
|
||||
factor = ff.ConvInputKroneckerFactor(
|
||||
tensor, (1, 2, 1, 1), [1, 1, 1, 1], 'SAME', has_bias=True)
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
new_cov = sess.run(factor.make_covariance_update_op(.5))
|
||||
@ -533,8 +547,9 @@ class ConvInputKroneckerFactorTest(test.TestCase):
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.constant(
|
||||
np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32)
|
||||
factor = ff.ConvInputKroneckerFactor(tensor, (1, 2, 1, 1), [1, 1, 1, 1],
|
||||
'SAME')
|
||||
factor = ff.ConvInputKroneckerFactor(tensor, (1, 2, 1, 1),
|
||||
[1, 1, 1, 1], 'SAME')
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
new_cov = sess.run(factor.make_covariance_update_op(.5))
|
||||
@ -548,6 +563,7 @@ class ConvOutputKroneckerFactorTest(test.TestCase):
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.ones((2, 3, 4, 5), name='a/b/c')
|
||||
factor = ff.ConvOutputKroneckerFactor((tensor,))
|
||||
factor.instantiate_cov_variables()
|
||||
self.assertEqual([5, 5], factor.get_cov().get_shape().as_list())
|
||||
|
||||
def testConvOutputKroneckerFactorInitFloat64(self):
|
||||
@ -556,6 +572,7 @@ class ConvOutputKroneckerFactorTest(test.TestCase):
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c')
|
||||
factor = ff.ConvOutputKroneckerFactor((tensor,))
|
||||
factor.instantiate_cov_variables()
|
||||
cov = factor.get_cov()
|
||||
self.assertEqual(cov.dtype, dtype)
|
||||
self.assertEqual([5, 5], cov.get_shape().as_list())
|
||||
@ -565,13 +582,14 @@ class ConvOutputKroneckerFactorTest(test.TestCase):
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.ones((2, 3), name='a/b/c')
|
||||
with self.assertRaises(IndexError):
|
||||
ff.ConvOutputKroneckerFactor(tensor)
|
||||
ff.ConvOutputKroneckerFactor((tensor,))
|
||||
|
||||
def testMakeCovarianceUpdateOp(self):
|
||||
with tf_ops.Graph().as_default(), self.test_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = np.arange(1, 17).reshape(2, 2, 2, 2).astype(np.float32)
|
||||
factor = ff.ConvOutputKroneckerFactor((array_ops.constant(tensor),))
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
new_cov = sess.run(factor.make_covariance_update_op(.5))
|
||||
@ -586,6 +604,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
|
||||
tensor = array_ops.ones((2, 3), name='a/b/c')
|
||||
tensor_list = [tensor]
|
||||
factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False)
|
||||
factor.instantiate_cov_variables()
|
||||
self.assertEqual([3, 3], factor.get_cov().get_shape().as_list())
|
||||
|
||||
def testFullyConnectedMultiKFInitFloat64(self):
|
||||
@ -595,6 +614,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
|
||||
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
|
||||
tensor_list = [tensor]
|
||||
factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False)
|
||||
factor.instantiate_cov_variables()
|
||||
cov = factor.get_cov()
|
||||
self.assertEqual(cov.dtype, dtype)
|
||||
self.assertEqual([3, 3], cov.get_shape().as_list())
|
||||
@ -605,6 +625,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
|
||||
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
|
||||
tensor_list = [tensor]
|
||||
factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=True)
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
new_cov = sess.run(factor.make_covariance_update_op(.5))
|
||||
@ -616,6 +637,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
|
||||
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
|
||||
tensor_list = [tensor]
|
||||
factor = ff.FullyConnectedMultiKF((tensor_list,))
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
new_cov = sess.run(factor.make_covariance_update_op(.5))
|
||||
|
@ -237,16 +237,16 @@ class LayerCollectionTest(test.TestCase):
|
||||
|
||||
# Create a new loss function by name.
|
||||
lc.register_categorical_predictive_distribution(logits, name='loss1')
|
||||
self.assertEqual(1, len(lc.losses))
|
||||
self.assertEqual(1, len(lc.towers_by_loss))
|
||||
|
||||
# Add logits to same loss function.
|
||||
lc.register_categorical_predictive_distribution(
|
||||
logits, name='loss1', reuse=True)
|
||||
self.assertEqual(1, len(lc.losses))
|
||||
self.assertEqual(1, len(lc.towers_by_loss))
|
||||
|
||||
# Add another new loss function.
|
||||
lc.register_categorical_predictive_distribution(logits, name='loss2')
|
||||
self.assertEqual(2, len(lc.losses))
|
||||
self.assertEqual(2, len(lc.towers_by_loss))
|
||||
|
||||
def testLossFunctionWithoutName(self):
|
||||
"""Ensure loss functions get unique names if 'name' not specified."""
|
||||
@ -298,13 +298,9 @@ class LayerCollectionTest(test.TestCase):
|
||||
name='loss1',
|
||||
reuse=layer_collection.VARIABLE_SCOPE)
|
||||
|
||||
self.assertEqual(len(lc.losses), 1)
|
||||
loss = lc.losses[0]
|
||||
|
||||
self.assertEqual(len(lc.towers_by_loss), 1)
|
||||
# Three successful registrations.
|
||||
self.assertEqual(loss.params.shape.as_list(),
|
||||
[3 * batch_size, output_size])
|
||||
self.assertEqual(loss.targets.shape.as_list(), [3 * batch_size])
|
||||
self.assertEqual(len(lc.towers_by_loss[0]), 3)
|
||||
|
||||
def testRegisterCategoricalPredictiveDistributionBatchSize1(self):
|
||||
with ops.Graph().as_default():
|
||||
@ -479,17 +475,6 @@ class LayerCollectionTest(test.TestCase):
|
||||
variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||
self.assertTrue(all([var.name.startswith(scope) for var in variables]))
|
||||
|
||||
def testGetUseCountMap(self):
|
||||
"""Ensure get_use_count_map() sums 'num_registered_minibatches'."""
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.fisher_blocks = {
|
||||
'a': MockFisherBlock(),
|
||||
('a', 'c'): MockFisherBlock(),
|
||||
('b', 'c'): MockFisherBlock()
|
||||
}
|
||||
use_count_map = lc.get_use_count_map()
|
||||
self.assertDictEqual({'a': 4, 'b': 2, 'c': 4}, use_count_map)
|
||||
|
||||
def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self):
|
||||
x = variable_scope.get_variable('x', shape=())
|
||||
y = variable_scope.get_variable('y', shape=())
|
||||
|
@ -24,7 +24,6 @@ from tensorflow.contrib.kfac.python.ops import loss_functions
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -97,22 +96,6 @@ class CategoricalLogitsNegativeLogProbLossTest(test.TestCase):
|
||||
# difficult to say if the output is correct or not...
|
||||
neg_log_prob = sess.run(neg_log_prob)
|
||||
|
||||
def testMultiMinibatchRegistration(self):
|
||||
"""Ensure this loss function supports registering multiple minibatches."""
|
||||
with ops.Graph().as_default():
|
||||
tower_logits = []
|
||||
loss = None
|
||||
num_towers = 5
|
||||
for _ in range(num_towers):
|
||||
logits = random_ops.random_uniform(shape=[2, 3])
|
||||
tower_logits.append(logits)
|
||||
if loss is None:
|
||||
loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
|
||||
else:
|
||||
loss.register_additional_minibatch(logits)
|
||||
self.assertListEqual(loss.input_minibatches, tower_logits)
|
||||
self.assertEqual(loss.num_registered_minibatches, num_towers)
|
||||
|
||||
def testMultiplyFisherSingleVector(self):
|
||||
with ops.Graph().as_default(), self.test_session() as sess:
|
||||
logits = np.array([1., 2., 3.])
|
||||
@ -203,23 +186,5 @@ class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase):
|
||||
# difficult to say if the output is correct or not...
|
||||
neg_log_prob = sess.run(neg_log_prob)
|
||||
|
||||
def testMultiMinibatchRegistration(self):
|
||||
"""Ensure this loss function supports registering multiple minibatches."""
|
||||
with ops.Graph().as_default():
|
||||
tower_logits = []
|
||||
loss = None
|
||||
num_towers = 5
|
||||
for _ in range(num_towers):
|
||||
logits = random_ops.random_uniform(shape=[2, 3])
|
||||
tower_logits.append(logits)
|
||||
if loss is None:
|
||||
loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
|
||||
logits)
|
||||
else:
|
||||
loss.register_additional_minibatch(logits)
|
||||
self.assertListEqual(loss.input_minibatches, tower_logits)
|
||||
self.assertEqual(loss.num_registered_minibatches, num_towers)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -27,6 +27,7 @@ from tensorflow.contrib.kfac.python.ops import utils
|
||||
from tensorflow.python.framework import ops as tf_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
@ -65,6 +66,13 @@ class _DeviceContextGenerator(object):
|
||||
yield
|
||||
|
||||
|
||||
def _make_thunk_on_device(func, device):
|
||||
def thunk():
|
||||
with tf_ops.device(device):
|
||||
return func()
|
||||
return thunk
|
||||
|
||||
|
||||
class FisherEstimator(object):
|
||||
"""Fisher estimator class supporting various approximations of the Fisher.
|
||||
|
||||
@ -83,26 +91,35 @@ class FisherEstimator(object):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
damping_fn,
|
||||
variables,
|
||||
cov_ema_decay,
|
||||
damping,
|
||||
layer_collection,
|
||||
exps=(-1,),
|
||||
estimation_mode="gradients",
|
||||
colocate_gradients_with_ops=True,
|
||||
cov_devices=None,
|
||||
inv_devices=None):
|
||||
name="FisherEstimator"):
|
||||
"""Create a FisherEstimator object.
|
||||
|
||||
Args:
|
||||
damping_fn: Function, accepts no arguments and returns damping value.
|
||||
variables: A list of the variables for which to estimate the Fisher. This
|
||||
must match the variables registered in layer_collection (if it is not
|
||||
None).
|
||||
cov_ema_decay: The decay factor used when calculating the covariance
|
||||
estimate moving averages.
|
||||
damping: float. The damping factor used to stabilize training due to
|
||||
errors in the local approximation with the Fisher information matrix,
|
||||
and to regularize the update direction by making it closer to the
|
||||
gradient. (Higher damping means the update looks more like a standard
|
||||
gradient update - see Tikhonov regularization.)
|
||||
layer_collection: The layer collection object, which holds the fisher
|
||||
blocks, kronecker factors, and losses associated with the
|
||||
graph.
|
||||
exps: List of floats or ints. These represent the different matrix
|
||||
powers of the approximate Fisher that the FisherEstimator will be able
|
||||
to multiply vectors by. If the user asks for a matrix power other
|
||||
one of these (or 1, which is always supported), there will be a
|
||||
failure. (Default: (-1,))
|
||||
estimation_mode: The type of estimator to use for the Fishers. Can be
|
||||
'gradients', 'empirical', 'curvature_prop', or 'exact'.
|
||||
(Default: 'gradients'). 'gradients' is the basic estimation approach
|
||||
@ -121,19 +138,15 @@ class FisherEstimator(object):
|
||||
equal to the output dimension, roughly speaking.
|
||||
colocate_gradients_with_ops: Whether we should request gradients be
|
||||
colocated with their respective ops. (Default: True)
|
||||
cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
|
||||
computations will be placed on these devices in a round-robin fashion.
|
||||
Can be None, which means that no devices are specified.
|
||||
inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
|
||||
computations will be placed on these devices in a round-robin fashion.
|
||||
Can be None, which means that no devices are specified.
|
||||
|
||||
name: A string. A name given to this estimator, which is added to the
|
||||
variable scope when constructing variables and ops.
|
||||
(Default: "FisherEstimator")
|
||||
Raises:
|
||||
ValueError: If no losses have been registered with layer_collection.
|
||||
"""
|
||||
self._damping_fn = damping_fn
|
||||
self._cov_ema_decay = cov_ema_decay
|
||||
self._variables = variables
|
||||
self._cov_ema_decay = cov_ema_decay
|
||||
self._damping = damping
|
||||
self._estimation_mode = estimation_mode
|
||||
self._layers = layer_collection
|
||||
self._layers.create_subgraph()
|
||||
@ -146,30 +159,13 @@ class FisherEstimator(object):
|
||||
}
|
||||
self._colocate_gradients_with_ops = colocate_gradients_with_ops
|
||||
|
||||
# TODO(b/70674513): Factor device placement outside of this class.
|
||||
self._cov_device_context_generator = _DeviceContextGenerator(cov_devices)
|
||||
if inv_devices == cov_devices:
|
||||
self._inv_device_context_generator = self._cov_device_context_generator
|
||||
else:
|
||||
self._inv_device_context_generator = _DeviceContextGenerator(inv_devices)
|
||||
self._made_vars = False
|
||||
self._exps = exps
|
||||
|
||||
self._name = name
|
||||
|
||||
self._instantiate_factors()
|
||||
|
||||
self.cov_update_thunks = [
|
||||
self._create_cov_update_thunk(factor)
|
||||
for factor in self._layers.get_factors()
|
||||
]
|
||||
self.cov_update_ops = [thunk() for thunk in self.cov_update_thunks]
|
||||
self.cov_update_op = control_flow_ops.group(
|
||||
self.cov_update_ops, name="cov_update_op")
|
||||
|
||||
self.inv_update_thunks = [
|
||||
self._create_inv_update_thunk(factor)
|
||||
for factor in self._layers.get_factors()
|
||||
]
|
||||
self.inv_update_ops = [thunk() for thunk in self.inv_update_thunks]
|
||||
self.inv_update_op = control_flow_ops.group(
|
||||
self.inv_update_ops, name="inv_update_op")
|
||||
self._register_matrix_functions()
|
||||
|
||||
@property
|
||||
def variables(self):
|
||||
@ -177,7 +173,21 @@ class FisherEstimator(object):
|
||||
|
||||
@property
|
||||
def damping(self):
|
||||
return self._damping_fn()
|
||||
return self._damping
|
||||
|
||||
@property
|
||||
def blocks(self):
|
||||
"""All registered FisherBlocks."""
|
||||
return self._layers.get_blocks()
|
||||
|
||||
@property
|
||||
def factors(self):
|
||||
"""All registered FisherFactors."""
|
||||
return self._layers.get_factors()
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
def _apply_transformation(self, vecs_and_vars, transform):
|
||||
"""Applies an block-wise transformation to the corresponding vectors.
|
||||
@ -212,9 +222,7 @@ class FisherEstimator(object):
|
||||
A list of (transformed vector, var) pairs in the same order as
|
||||
vecs_and_vars.
|
||||
"""
|
||||
|
||||
return self._apply_transformation(vecs_and_vars,
|
||||
lambda fb, vec: fb.multiply_inverse(vec))
|
||||
return self.multiply_matpower(-1, vecs_and_vars)
|
||||
|
||||
def multiply(self, vecs_and_vars):
|
||||
"""Multiplies the vectors by the corresponding (damped) blocks.
|
||||
@ -226,9 +234,22 @@ class FisherEstimator(object):
|
||||
A list of (transformed vector, var) pairs in the same order as
|
||||
vecs_and_vars.
|
||||
"""
|
||||
return self.multiply_matpower(1, vecs_and_vars)
|
||||
|
||||
return self._apply_transformation(vecs_and_vars,
|
||||
lambda fb, vec: fb.multiply(vec))
|
||||
def multiply_matpower(self, exp, vecs_and_vars):
|
||||
"""Multiplies the vecs by the corresponding matrix powers of the blocks.
|
||||
|
||||
Args:
|
||||
exp: A float representing the power to raise the blocks by before
|
||||
multiplying it by the vector.
|
||||
vecs_and_vars: List of (vector, variable) pairs.
|
||||
|
||||
Returns:
|
||||
A list of (transformed vector, var) pairs in the same order as
|
||||
vecs_and_vars.
|
||||
"""
|
||||
fcn = lambda fb, vec: fb.multiply_matpower(vec, exp)
|
||||
return self._apply_transformation(vecs_and_vars, fcn)
|
||||
|
||||
def _instantiate_factors(self):
|
||||
"""Instantiates FisherFactors' variables.
|
||||
@ -236,9 +257,9 @@ class FisherEstimator(object):
|
||||
Raises:
|
||||
ValueError: If estimation_mode was improperly specified at construction.
|
||||
"""
|
||||
fisher_blocks_list = self._layers.get_blocks()
|
||||
blocks = self.blocks
|
||||
tensors_to_compute_grads = [
|
||||
fb.tensors_to_compute_grads() for fb in fisher_blocks_list
|
||||
block.tensors_to_compute_grads() for block in blocks
|
||||
]
|
||||
|
||||
try:
|
||||
@ -248,45 +269,275 @@ class FisherEstimator(object):
|
||||
raise ValueError("Unrecognized value {} for estimation_mode.".format(
|
||||
self._estimation_mode))
|
||||
|
||||
# TODO(b/68033310): This loop round-robins the "concat" operations which
|
||||
# gather the inputs for the cov_updates. In future, we might do these
|
||||
# computations locally then communicate the results, which would require a
|
||||
# modification to this code.
|
||||
for grads_list, fb in zip(grads_lists, fisher_blocks_list):
|
||||
with self._cov_device_context_generator():
|
||||
fb.instantiate_factors(grads_list, self.damping)
|
||||
for grads_list, block in zip(grads_lists, blocks):
|
||||
block.instantiate_factors(grads_list, self.damping)
|
||||
|
||||
def _create_cov_update_thunk(self, factor):
|
||||
def _check_vars_unmade_and_set_made_flag(self):
|
||||
if self._made_vars:
|
||||
raise Exception("Already made variables.")
|
||||
self._made_vars = True
|
||||
|
||||
def made_vars(self):
|
||||
return self._made_vars
|
||||
|
||||
def _register_matrix_functions(self):
|
||||
for exp in self._exps:
|
||||
for block in self.blocks:
|
||||
block.register_matpower(exp)
|
||||
|
||||
def make_ops_and_vars(self, scope=None):
|
||||
"""Make ops and vars with no specific device placement.
|
||||
|
||||
See make_ops_and_vars_round_robin for further details.
|
||||
|
||||
Args:
|
||||
scope: A string or None. If None it will be set to the name of this
|
||||
estimator (given by the name property). All variables will be created,
|
||||
and all ops will execute, inside of a variable scope of the given
|
||||
name. (Default: None)
|
||||
Returns:
|
||||
cov_update_ops: List of ops that compute the cov updates. Corresponds
|
||||
one-to-one with the list of factors given by the "factors" property.
|
||||
cov_update_op: cov_update_ops grouped into a single op.
|
||||
inv_update_ops: List of ops that compute the inv updates. Corresponds
|
||||
one-to-one with the list of factors given by the "factors" property.
|
||||
inv_update_op: inv_update_ops grouped into a single op.
|
||||
cov_update_thunks: Thunks that make the ops in cov_update_ops.
|
||||
inv_update_thunks: Thunks that make the ops in inv_update_ops.
|
||||
"""
|
||||
return self.make_ops_and_vars_round_robin(scope=scope)
|
||||
|
||||
# TODO(b/70674513): Factor device placement outside of this class.
|
||||
def make_ops_and_vars_round_robin(self, scope=None, cov_devices=None,
|
||||
inv_devices=None):
|
||||
"""Make ops and vars with a round-robin device placement strategy.
|
||||
|
||||
For each factor, all of that factor's cov variables and their associated
|
||||
update ops will be placed on a particular device. A new device is chosen
|
||||
for each factor by cycling through list of devices in the cov_devices
|
||||
argument. If cov_devices is None then no explicit device placement occurs.
|
||||
|
||||
An analogous strategy is followed for inverse update ops, with the list of
|
||||
devices being given by the inv_devices argument.
|
||||
|
||||
Inverse variables on the other hand are not placed on any specific device
|
||||
(they will just use the current the device placement context, whatever
|
||||
that happens to be). The idea is that the inverse variable belong where
|
||||
they will be accessed most often, which is the device that actually applies
|
||||
the preconditioner to the gradient. The user will be responsible for setting
|
||||
the device context for this.
|
||||
|
||||
Args:
|
||||
scope: A string or None. If None it will be set to the name of this
|
||||
estimator (given by the name property). All variables will be created,
|
||||
and all ops will execute, inside of a variable scope of the given
|
||||
name. (Default: None)
|
||||
cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
|
||||
computations will be placed on these devices in a round-robin fashion.
|
||||
Can be None, which means that no devices are specified.
|
||||
inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
|
||||
computations will be placed on these devices in a round-robin fashion.
|
||||
Can be None, which means that no devices are specified.
|
||||
|
||||
Returns:
|
||||
cov_update_ops: List of ops that compute the cov updates. Corresponds
|
||||
one-to-one with the list of factors given by the "factors" property.
|
||||
cov_update_op: cov_update_ops grouped into a single op.
|
||||
inv_update_ops: List of ops that compute the inv updates. Corresponds
|
||||
one-to-one with the list of factors given by the "factors" property.
|
||||
inv_update_op: inv_update_ops grouped into a single op.
|
||||
cov_update_thunks: Thunks that make the ops in cov_update_ops.
|
||||
inv_update_thunks: Thunks that make the ops in inv_update_ops.
|
||||
"""
|
||||
(cov_update_thunks,
|
||||
inv_update_thunks) = self.make_vars_and_create_op_thunks_round_robin(
|
||||
scope=scope,
|
||||
cov_devices=cov_devices,
|
||||
inv_devices=inv_devices)
|
||||
cov_update_ops = [thunk() for thunk in cov_update_thunks]
|
||||
inv_update_ops = [thunk() for thunk in inv_update_thunks]
|
||||
|
||||
scope = self.name if scope is None else scope
|
||||
with variable_scope.variable_scope(scope):
|
||||
cov_update_op = control_flow_ops.group(cov_update_ops,
|
||||
name="cov_update_op")
|
||||
inv_update_op = control_flow_ops.group(inv_update_ops,
|
||||
name="inv_update_op")
|
||||
|
||||
return (cov_update_ops, cov_update_op, inv_update_ops, inv_update_op,
|
||||
cov_update_thunks, inv_update_thunks)
|
||||
|
||||
def make_vars_and_create_op_thunks_round_robin(self,
|
||||
scope=None,
|
||||
cov_devices=None,
|
||||
inv_devices=None):
|
||||
"""Make vars and create op thunks w/ a round-robin device placement strat.
|
||||
|
||||
For each factor, all of that factor's cov variables and their associated
|
||||
update ops will be placed on a particular device. A new device is chosen
|
||||
for each factor by cycling through list of devices in the cov_devices
|
||||
argument. If cov_devices is None then no explicit device placement occurs.
|
||||
|
||||
An analogous strategy is followed for inverse update ops, with the list of
|
||||
devices being given by the inv_devices argument.
|
||||
|
||||
Inverse variables on the other hand are not placed on any specific device
|
||||
(they will just use the current the device placement context, whatever
|
||||
that happens to be). The idea is that the inverse variable belong where
|
||||
they will be accessed most often, which is the device that actually applies
|
||||
the preconditioner to the gradient. The user will be responsible for setting
|
||||
the device context for this.
|
||||
|
||||
Args:
|
||||
scope: A string or None. If None it will be set to the name of this
|
||||
estimator (given by the name property). All variables will be created,
|
||||
and all thunks will execute, inside of a variable scope of the given
|
||||
name. (Default: None)
|
||||
cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
|
||||
computations will be placed on these devices in a round-robin fashion.
|
||||
Can be None, which means that no devices are specified.
|
||||
inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
|
||||
computations will be placed on these devices in a round-robin fashion.
|
||||
Can be None, which means that no devices are specified.
|
||||
Returns:
|
||||
cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
|
||||
the list of factors given by the "factors" property.
|
||||
inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
|
||||
the list of factors given by the "factors" property.
|
||||
"""
|
||||
|
||||
(cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw,
|
||||
inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope)
|
||||
|
||||
if cov_devices:
|
||||
cov_update_thunks = []
|
||||
for cov_variable_thunk, cov_update_thunk, device in zip(
|
||||
cov_variable_thunks_raw, cov_update_thunks_raw,
|
||||
itertools.cycle(cov_devices)):
|
||||
with tf_ops.device(device):
|
||||
cov_variable_thunk()
|
||||
cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk,
|
||||
device))
|
||||
else:
|
||||
for cov_variable_thunk in cov_variable_thunks_raw:
|
||||
cov_variable_thunk()
|
||||
cov_update_thunks = cov_update_thunks_raw
|
||||
|
||||
for inv_variable_thunk in inv_variable_thunks_raw:
|
||||
inv_variable_thunk()
|
||||
|
||||
if inv_devices:
|
||||
inv_update_thunks = []
|
||||
for inv_update_thunk, device in zip(inv_update_thunks_raw,
|
||||
itertools.cycle(inv_devices)):
|
||||
inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk,
|
||||
device))
|
||||
else:
|
||||
inv_update_thunks = inv_update_thunks_raw
|
||||
|
||||
return cov_update_thunks, inv_update_thunks
|
||||
|
||||
def create_ops_and_vars_thunks(self, scope=None):
|
||||
"""Create thunks that make the ops and vars on demand.
|
||||
|
||||
This function returns 4 lists of thunks: cov_variable_thunks,
|
||||
cov_update_thunks, inv_variable_thunks, and inv_update_thunks.
|
||||
|
||||
The length of each list is the number of factors and the i-th element of
|
||||
each list corresponds to the i-th factor (given by the "factors" property).
|
||||
|
||||
Note that the execution of these thunks must happen in a certain
|
||||
partial order. The i-th element of cov_variable_thunks must execute
|
||||
before the i-th element of cov_update_thunks (and also the i-th element
|
||||
of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks
|
||||
must execute before the i-th element of inv_update_thunks.
|
||||
|
||||
TL;DR (oversimplified): Execute the thunks according to the order that
|
||||
they are returned.
|
||||
|
||||
Args:
|
||||
scope: A string or None. If None it will be set to the name of this
|
||||
estimator (given by the name property). All thunks will execute inside
|
||||
of a variable scope of the given name. (Default: None)
|
||||
Returns:
|
||||
cov_variable_thunks: A list of thunks that make the cov variables.
|
||||
cov_update_thunks: A list of thunks that make the cov update ops.
|
||||
inv_variable_thunks: A list of thunks that make the inv variables.
|
||||
inv_update_thunks: A list of thunks that make the inv update ops.
|
||||
"""
|
||||
self._check_vars_unmade_and_set_made_flag()
|
||||
|
||||
scope = self.name if scope is None else scope
|
||||
|
||||
cov_variable_thunks = [
|
||||
self._create_cov_variable_thunk(factor, scope)
|
||||
for factor in self.factors
|
||||
]
|
||||
cov_update_thunks = [
|
||||
self._create_cov_update_thunk(factor, scope) for factor in self.factors
|
||||
]
|
||||
inv_variable_thunks = [
|
||||
self._create_inv_variable_thunk(factor, scope)
|
||||
for factor in self.factors
|
||||
]
|
||||
inv_update_thunks = [
|
||||
self._create_inv_update_thunk(factor, scope) for factor in self.factors
|
||||
]
|
||||
|
||||
return (cov_variable_thunks, cov_update_thunks,
|
||||
inv_variable_thunks, inv_update_thunks)
|
||||
|
||||
def _create_cov_variable_thunk(self, factor, scope):
|
||||
"""Constructs a covariance variable thunk for a single FisherFactor."""
|
||||
|
||||
def thunk():
|
||||
with variable_scope.variable_scope(scope):
|
||||
return factor.instantiate_cov_variables()
|
||||
|
||||
return thunk
|
||||
|
||||
def _create_cov_update_thunk(self, factor, scope):
|
||||
"""Constructs a covariance update thunk for a single FisherFactor."""
|
||||
|
||||
def thunk():
|
||||
with tf_ops.name_scope(
|
||||
"create_cov_update_thunk", values=[self._cov_ema_decay]):
|
||||
with variable_scope.variable_scope(scope):
|
||||
return factor.make_covariance_update_op(self._cov_ema_decay)
|
||||
|
||||
return thunk
|
||||
|
||||
def _create_inv_update_thunk(self, factor):
|
||||
def _create_inv_variable_thunk(self, factor, scope):
|
||||
"""Constructs a inverse variable thunk for a single FisherFactor."""
|
||||
|
||||
def thunk():
|
||||
with variable_scope.variable_scope(scope):
|
||||
return factor.instantiate_inv_variables()
|
||||
|
||||
return thunk
|
||||
|
||||
def _create_inv_update_thunk(self, factor, scope):
|
||||
"""Constructs an inverse update thunk for a single FisherFactor."""
|
||||
|
||||
def thunk():
|
||||
with tf_ops.name_scope("create_inv_update_thunk"):
|
||||
with self._inv_device_context_generator():
|
||||
return control_flow_ops.group(factor.make_inverse_update_ops())
|
||||
with variable_scope.variable_scope(scope):
|
||||
return control_flow_ops.group(factor.make_inverse_update_ops())
|
||||
|
||||
return thunk
|
||||
|
||||
def _get_grads_lists_gradients(self, tensors):
|
||||
# Passing in a list of loss values is better than passing in the sum as
|
||||
# the latter creates unnessesary ops on the default device
|
||||
grads_flat = gradients_impl.gradients(
|
||||
self._layers.total_sampled_loss(),
|
||||
self._layers.eval_losses_on_samples(),
|
||||
nest.flatten(tensors),
|
||||
colocate_gradients_with_ops=self._colocate_gradients_with_ops)
|
||||
grads_all = nest.pack_sequence_as(tensors, grads_flat)
|
||||
return tuple((grad,) for grad in grads_all)
|
||||
|
||||
def _get_grads_lists_empirical(self, tensors):
|
||||
# Passing in a list of loss values is better than passing in the sum as
|
||||
# the latter creates unnessesary ops on the default device
|
||||
grads_flat = gradients_impl.gradients(
|
||||
self._layers.total_loss(),
|
||||
self._layers.eval_losses(),
|
||||
nest.flatten(tensors),
|
||||
colocate_gradients_with_ops=self._colocate_gradients_with_ops)
|
||||
grads_all = nest.pack_sequence_as(tensors, grads_flat)
|
||||
@ -295,9 +546,10 @@ class FisherEstimator(object):
|
||||
def _get_transformed_random_signs(self):
|
||||
transformed_random_signs = []
|
||||
for loss in self._layers.losses:
|
||||
transformed_random_signs.append(
|
||||
loss.multiply_fisher_factor(
|
||||
utils.generate_random_signs(loss.fisher_factor_inner_shape)))
|
||||
with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
|
||||
transformed_random_signs.append(
|
||||
loss.multiply_fisher_factor(
|
||||
utils.generate_random_signs(loss.fisher_factor_inner_shape)))
|
||||
return transformed_random_signs
|
||||
|
||||
def _get_grads_lists_curvature_prop(self, tensors):
|
||||
@ -316,13 +568,14 @@ class FisherEstimator(object):
|
||||
# Loop over all coordinates of all losses.
|
||||
grads_all = []
|
||||
for loss in self._layers.losses:
|
||||
for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]):
|
||||
transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
|
||||
index)
|
||||
grads_flat = gradients_impl.gradients(
|
||||
loss.inputs,
|
||||
nest.flatten(tensors),
|
||||
grad_ys=transformed_one_hot,
|
||||
colocate_gradients_with_ops=self._colocate_gradients_with_ops)
|
||||
grads_all.append(nest.pack_sequence_as(tensors, grads_flat))
|
||||
with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
|
||||
for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]):
|
||||
transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
|
||||
index)
|
||||
grads_flat = gradients_impl.gradients(
|
||||
loss.inputs,
|
||||
nest.flatten(tensors),
|
||||
grad_ys=transformed_one_hot,
|
||||
colocate_gradients_with_ops=self._colocate_gradients_with_ops)
|
||||
grads_all.append(nest.pack_sequence_as(tensors, grads_flat))
|
||||
return zip(*grads_all)
|
||||
|
@ -121,12 +121,44 @@ def compute_pi_adjusted_damping(left_cov, right_cov, damping):
|
||||
return (damping, damping)
|
||||
|
||||
|
||||
class PackagedFunc(object):
|
||||
"""A Python thunk with a stable ID.
|
||||
|
||||
Enables stable names for lambdas.
|
||||
"""
|
||||
|
||||
def __init__(self, func, func_id):
|
||||
"""Initializes PackagedFunc.
|
||||
|
||||
Args:
|
||||
func: a zero-arg Python function.
|
||||
func_id: a hashable, function that produces a hashable, or a list/tuple
|
||||
thereof.
|
||||
"""
|
||||
self._func = func
|
||||
func_id = func_id if isinstance(func_id, (tuple, list)) else (func_id,)
|
||||
self._func_id = func_id
|
||||
|
||||
def __call__(self):
|
||||
return self._func()
|
||||
|
||||
@property
|
||||
def func_id(self):
|
||||
"""A hashable identifier for this function."""
|
||||
return tuple(elt() if callable(elt) else elt for elt in self._func_id)
|
||||
|
||||
|
||||
def _package_func(func, func_id):
|
||||
return PackagedFunc(func, func_id)
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class FisherBlock(object):
|
||||
"""Abstract base class for objects modeling approximate Fisher matrix blocks.
|
||||
|
||||
Subclasses must implement multiply_inverse(), instantiate_factors(), and
|
||||
tensors_to_compute_grads() methods.
|
||||
Subclasses must implement register_matpower, multiply_matpower,
|
||||
instantiate_factors, tensors_to_compute_grads, and num_registered_minibatches
|
||||
methods.
|
||||
"""
|
||||
|
||||
def __init__(self, layer_collection):
|
||||
@ -145,6 +177,32 @@ class FisherBlock(object):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def register_matpower(self, exp):
|
||||
"""Registers a matrix power to be computed by the block.
|
||||
|
||||
Args:
|
||||
exp: A float representing the power to raise the block by.
|
||||
"""
|
||||
pass
|
||||
|
||||
def register_inverse(self):
|
||||
"""Registers a matrix inverse to be computed by the block."""
|
||||
self.register_matpower(-1)
|
||||
|
||||
@abc.abstractmethod
|
||||
def multiply_matpower(self, vector, exp):
|
||||
"""Multiplies the vector by the (damped) matrix-power of the block.
|
||||
|
||||
Args:
|
||||
vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
|
||||
exp: A float representing the power to raise the block by before
|
||||
multiplying it by the vector.
|
||||
|
||||
Returns:
|
||||
The vector left-multiplied by the (damped) matrix-power of the block.
|
||||
"""
|
||||
pass
|
||||
|
||||
def multiply_inverse(self, vector):
|
||||
"""Multiplies the vector by the (damped) inverse of the block.
|
||||
|
||||
@ -154,9 +212,8 @@ class FisherBlock(object):
|
||||
Returns:
|
||||
The vector left-multiplied by the (damped) inverse of the block.
|
||||
"""
|
||||
pass
|
||||
return self.multiply_matpower(vector, -1)
|
||||
|
||||
@abc.abstractmethod
|
||||
def multiply(self, vector):
|
||||
"""Multiplies the vector by the (damped) block.
|
||||
|
||||
@ -166,7 +223,7 @@ class FisherBlock(object):
|
||||
Returns:
|
||||
The vector left-multiplied by the (damped) block.
|
||||
"""
|
||||
pass
|
||||
return self.multiply_matpower(vector, 1)
|
||||
|
||||
@abc.abstractmethod
|
||||
def tensors_to_compute_grads(self):
|
||||
@ -207,21 +264,18 @@ class FullFB(FisherBlock):
|
||||
super(FullFB, self).__init__(layer_collection)
|
||||
|
||||
def instantiate_factors(self, grads_list, damping):
|
||||
self._damping = damping
|
||||
self._damping_func = _package_func(lambda: damping, (damping,))
|
||||
|
||||
self._factor = self._layer_collection.make_or_get_factor(
|
||||
fisher_factors.FullFactor, (grads_list, self._batch_size))
|
||||
self._factor.register_damped_inverse(damping)
|
||||
|
||||
def multiply_inverse(self, vector):
|
||||
vector_flat = utils.tensors_to_column(vector)
|
||||
out_flat = self._factor.left_multiply_inverse(
|
||||
vector_flat, self._damping)
|
||||
return utils.column_to_tensors(vector, out_flat)
|
||||
def register_matpower(self, exp):
|
||||
self._factor.register_matpower(exp, self._damping_func)
|
||||
|
||||
def multiply(self, vector):
|
||||
def multiply_matpower(self, vector, exp):
|
||||
vector_flat = utils.tensors_to_column(vector)
|
||||
out_flat = self._factor.left_multiply(
|
||||
vector_flat, self._damping)
|
||||
out_flat = self._factor.left_multiply_matpower(
|
||||
vector_flat, exp, self._damping_func)
|
||||
return utils.column_to_tensors(vector, out_flat)
|
||||
|
||||
def full_fisher_block(self):
|
||||
@ -271,22 +325,20 @@ class NaiveDiagonalFB(FisherBlock):
|
||||
super(NaiveDiagonalFB, self).__init__(layer_collection)
|
||||
|
||||
def instantiate_factors(self, grads_list, damping):
|
||||
self._damping = damping
|
||||
self._damping_func = _package_func(lambda: damping, (damping,))
|
||||
|
||||
self._factor = self._layer_collection.make_or_get_factor(
|
||||
fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size))
|
||||
|
||||
def multiply_inverse(self, vector):
|
||||
vector_flat = utils.tensors_to_column(vector)
|
||||
print("vector_flat: %s" % vector_flat)
|
||||
out_flat = self._factor.left_multiply_inverse(
|
||||
vector_flat, self._damping)
|
||||
print("out_flat: %s" % out_flat)
|
||||
return utils.column_to_tensors(vector, out_flat)
|
||||
def register_matpower(self, exp):
|
||||
# Not needed for this. Matrix powers are computed on demand in the
|
||||
# diagonal case
|
||||
pass
|
||||
|
||||
def multiply(self, vector):
|
||||
def multiply_matpower(self, vector, exp):
|
||||
vector_flat = utils.tensors_to_column(vector)
|
||||
out_flat = self._factor.left_multiply(
|
||||
vector_flat, self._damping)
|
||||
out_flat = self._factor.left_multiply_matpower(
|
||||
vector_flat, exp, self._damping_func)
|
||||
return utils.column_to_tensors(vector, out_flat)
|
||||
|
||||
def full_fisher_block(self):
|
||||
@ -312,7 +364,89 @@ class NaiveDiagonalFB(FisherBlock):
|
||||
return math_ops.reduce_sum(self._batch_sizes)
|
||||
|
||||
|
||||
class FullyConnectedDiagonalFB(FisherBlock):
|
||||
class InputOutputMultiMinibatch(object):
|
||||
"""Mix-in class for blocks with inputs & outputs and multiple mini-batches."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.__inputs = []
|
||||
self.__outputs = []
|
||||
super(InputOutputMultiMinibatch, self).__init__(*args, **kwargs)
|
||||
|
||||
def tensors_to_compute_grads(self):
|
||||
"""Tensors to compute derivative of loss with respect to."""
|
||||
return self._outputs
|
||||
|
||||
def register_additional_minibatch(self, inputs, outputs):
|
||||
self._inputs.append(inputs)
|
||||
self._outputs.append(outputs)
|
||||
|
||||
@property
|
||||
def num_registered_minibatches(self):
|
||||
result = len(self._inputs)
|
||||
assert result == len(self._outputs)
|
||||
return result
|
||||
|
||||
@property
|
||||
def _inputs(self):
|
||||
return self.__inputs
|
||||
|
||||
@property
|
||||
def _outputs(self):
|
||||
return self.__outputs
|
||||
|
||||
def _package_minibatches(self, grads_list):
|
||||
"""Constructs PartitionedTensor for inputs, grads_list.
|
||||
|
||||
The purpose of this method is to package up the towers/minibatch dimension
|
||||
of these arrays into PartitionedTensor objects.
|
||||
|
||||
Args:
|
||||
grads_list: 2-D list of Tensors. First index is for source, second
|
||||
index for tower.
|
||||
|
||||
Returns:
|
||||
inputs: PartitionedTensor.
|
||||
grads_list: Tuple of PartitionedTensors, one per source.
|
||||
"""
|
||||
inputs = utils.PartitionedTensor(self._inputs)
|
||||
grads_list = tuple(utils.PartitionedTensor(grads) for grads in grads_list)
|
||||
|
||||
return inputs, grads_list
|
||||
|
||||
def _package_minibatches_multi(self, grads_list):
|
||||
"""Constructs PartitionedTensors for inputs, grads_list.
|
||||
|
||||
The purpose of this method is to package up the towers/minibatch dimension
|
||||
of these arrays into PartitionedTensor objects.
|
||||
|
||||
This version of this function is for use with FisherBlocks that deal with
|
||||
multiple uses or time-steps. One PartitionedTensor is created for each
|
||||
use/time-step.
|
||||
|
||||
Args:
|
||||
grads_list: 3-D tuple of Tensors. First index is for source, second
|
||||
index is for tower, third is for use/time-step.
|
||||
|
||||
Returns:
|
||||
inputs: A tuple of PartitionedTensor's, one per use/time-step.
|
||||
grads_list: 2-D tuple of PartitionedTensors. First index is for source,
|
||||
second is for use/time-step.
|
||||
"""
|
||||
# self._inputs is a 2-D tuple. First index is tower/mini-batch, second is
|
||||
# use/time-step.
|
||||
inputs = self._inputs
|
||||
num_uses = len(inputs[0])
|
||||
assert all(len(input_) == num_uses for input_ in inputs)
|
||||
assert all(len(grad) == num_uses for grads in grads_list for grad in grads)
|
||||
|
||||
inputs = tuple(utils.PartitionedTensor(input_) for input_ in zip(*inputs))
|
||||
grads_list = tuple(tuple(utils.PartitionedTensor(grad)
|
||||
for grad in zip(*grads)) for grads in grads_list)
|
||||
|
||||
return inputs, grads_list
|
||||
|
||||
|
||||
class FullyConnectedDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
|
||||
"""FisherBlock for fully-connected (dense) layers using a diagonal approx.
|
||||
|
||||
Estimates the Fisher Information matrix's diagonal entries for a fully
|
||||
@ -344,79 +478,45 @@ class FullyConnectedDiagonalFB(FisherBlock):
|
||||
has_bias: Whether the component Kronecker factors have an additive bias.
|
||||
(Default: False)
|
||||
"""
|
||||
self._inputs = []
|
||||
self._outputs = []
|
||||
self._has_bias = has_bias
|
||||
|
||||
super(FullyConnectedDiagonalFB, self).__init__(layer_collection)
|
||||
|
||||
def instantiate_factors(self, grads_list, damping):
|
||||
inputs = _concat_along_batch_dim(self._inputs)
|
||||
grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
|
||||
inputs, grads_list = self._package_minibatches(grads_list)
|
||||
|
||||
self._damping = damping
|
||||
self._factor = self._layer_collection.make_or_get_factor(
|
||||
fisher_factors.FullyConnectedDiagonalFactor,
|
||||
(inputs, grads_list, self._has_bias))
|
||||
|
||||
def multiply_inverse(self, vector):
|
||||
"""Approximate damped inverse Fisher-vector product.
|
||||
self._damping_func = _package_func(lambda: damping, (damping,))
|
||||
|
||||
def register_matpower(self, exp):
|
||||
# Not needed for this. Matrix powers are computed on demand in the
|
||||
# diagonal case
|
||||
pass
|
||||
|
||||
def multiply_matpower(self, vector, exp):
|
||||
"""Multiplies the vector by the (damped) matrix-power of the block.
|
||||
|
||||
Args:
|
||||
vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape
|
||||
[input_size, output_size] corresponding to layer's weights. If not, a
|
||||
2-tuple of the former and a Tensor of shape [output_size] corresponding
|
||||
to the layer's bias.
|
||||
exp: A scalar representing the power to raise the block before multiplying
|
||||
it by the vector.
|
||||
|
||||
Returns:
|
||||
Tensor of the same shape, corresponding to the inverse Fisher-vector
|
||||
product.
|
||||
The vector left-multiplied by the (damped) matrix-power of the block.
|
||||
"""
|
||||
reshaped_vec = utils.layer_params_to_mat2d(vector)
|
||||
reshaped_out = self._factor.left_multiply_inverse(
|
||||
reshaped_vec, self._damping)
|
||||
reshaped_out = self._factor.left_multiply_matpower(
|
||||
reshaped_vec, exp, self._damping_func)
|
||||
return utils.mat2d_to_layer_params(vector, reshaped_out)
|
||||
|
||||
def multiply(self, vector):
|
||||
"""Approximate damped Fisher-vector product.
|
||||
|
||||
Args:
|
||||
vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape
|
||||
[input_size, output_size] corresponding to layer's weights. If not, a
|
||||
2-tuple of the former and a Tensor of shape [output_size] corresponding
|
||||
to the layer's bias.
|
||||
|
||||
Returns:
|
||||
Tensor of the same shape, corresponding to the Fisher-vector product.
|
||||
"""
|
||||
reshaped_vec = utils.layer_params_to_mat2d(vector)
|
||||
reshaped_out = self._factor.left_multiply(
|
||||
reshaped_vec, self._damping)
|
||||
return utils.mat2d_to_layer_params(vector, reshaped_out)
|
||||
|
||||
def tensors_to_compute_grads(self):
|
||||
"""Tensors to compute derivative of loss with respect to."""
|
||||
return self._outputs
|
||||
|
||||
def register_additional_minibatch(self, inputs, outputs):
|
||||
"""Registers an additional minibatch to the FisherBlock.
|
||||
|
||||
Args:
|
||||
inputs: Tensor of shape [batch_size, input_size]. Inputs to the
|
||||
matrix-multiply.
|
||||
outputs: Tensor of shape [batch_size, output_size]. Layer preactivations.
|
||||
"""
|
||||
self._inputs.append(inputs)
|
||||
self._outputs.append(outputs)
|
||||
|
||||
@property
|
||||
def num_registered_minibatches(self):
|
||||
result = len(self._inputs)
|
||||
assert result == len(self._outputs)
|
||||
return result
|
||||
|
||||
|
||||
class ConvDiagonalFB(FisherBlock):
|
||||
class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
|
||||
"""FisherBlock for convolutional layers using a diagonal approx.
|
||||
|
||||
Estimates the Fisher Information matrix's diagonal entries for a convolutional
|
||||
@ -454,8 +554,6 @@ class ConvDiagonalFB(FisherBlock):
|
||||
strides: The stride size in this layer (1-D Tensor of length 4).
|
||||
padding: The padding in this layer (e.g. "SAME").
|
||||
"""
|
||||
self._inputs = []
|
||||
self._outputs = []
|
||||
self._strides = tuple(strides) if isinstance(strides, list) else strides
|
||||
self._padding = padding
|
||||
self._has_bias = isinstance(params, (tuple, list))
|
||||
@ -466,55 +564,38 @@ class ConvDiagonalFB(FisherBlock):
|
||||
super(ConvDiagonalFB, self).__init__(layer_collection)
|
||||
|
||||
def instantiate_factors(self, grads_list, damping):
|
||||
# Concatenate inputs, grads_list into single Tensors.
|
||||
inputs = _concat_along_batch_dim(self._inputs)
|
||||
grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
|
||||
|
||||
# Infer number of locations upon which convolution is applied.
|
||||
inputs_shape = tuple(inputs.shape.as_list())
|
||||
inputs_shape = tuple(self._inputs[0].shape.as_list())
|
||||
self._num_locations = (
|
||||
inputs_shape[1] * inputs_shape[2] //
|
||||
(self._strides[1] * self._strides[2]))
|
||||
|
||||
self._damping = (self._num_locations
|
||||
* normalize_damping(damping, self._num_locations))
|
||||
inputs, grads_list = self._package_minibatches(grads_list)
|
||||
|
||||
self._factor = self._layer_collection.make_or_get_factor(
|
||||
fisher_factors.ConvDiagonalFactor,
|
||||
(inputs, grads_list, self._filter_shape, self._strides, self._padding,
|
||||
self._has_bias))
|
||||
(inputs, grads_list, self._filter_shape, self._strides,
|
||||
self._padding, self._has_bias))
|
||||
|
||||
def multiply_inverse(self, vector):
|
||||
def damping_func():
|
||||
return self._num_locations * normalize_damping(damping,
|
||||
self._num_locations)
|
||||
|
||||
damping_id = (self._num_locations, "mult", "normalize_damping", damping,
|
||||
self._num_locations)
|
||||
self._damping_func = _package_func(damping_func, damping_id)
|
||||
|
||||
def register_matpower(self, exp):
|
||||
# Not needed for this. Matrix powers are computed on demand in the
|
||||
# diagonal case
|
||||
pass
|
||||
|
||||
def multiply_matpower(self, vector, exp):
|
||||
reshaped_vect = utils.layer_params_to_mat2d(vector)
|
||||
reshaped_out = self._factor.left_multiply_inverse(
|
||||
reshaped_vect, self._damping)
|
||||
reshaped_out = self._factor.left_multiply_matpower(
|
||||
reshaped_vect, exp, self._damping_func)
|
||||
return utils.mat2d_to_layer_params(vector, reshaped_out)
|
||||
|
||||
def multiply(self, vector):
|
||||
reshaped_vect = utils.layer_params_to_mat2d(vector)
|
||||
reshaped_out = self._factor.left_multiply(
|
||||
reshaped_vect, self._damping)
|
||||
return utils.mat2d_to_layer_params(vector, reshaped_out)
|
||||
|
||||
def tensors_to_compute_grads(self):
|
||||
return self._outputs
|
||||
|
||||
def register_additional_minibatch(self, inputs, outputs):
|
||||
"""Registers an additional minibatch to the FisherBlock.
|
||||
|
||||
Args:
|
||||
inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to
|
||||
the convolution.
|
||||
outputs: Tensor of shape [batch_size, height, width, output_size]. Layer
|
||||
preactivations.
|
||||
"""
|
||||
self._inputs.append(inputs)
|
||||
self._outputs.append(outputs)
|
||||
|
||||
@property
|
||||
def num_registered_minibatches(self):
|
||||
return len(self._inputs)
|
||||
|
||||
|
||||
class KroneckerProductFB(FisherBlock):
|
||||
"""A base class for FisherBlocks with separate input and output factors.
|
||||
@ -523,22 +604,40 @@ class KroneckerProductFB(FisherBlock):
|
||||
output factors.
|
||||
"""
|
||||
|
||||
def _register_damped_input_and_output_inverses(self, damping):
|
||||
"""Registers damped inverses for both the input and output factors.
|
||||
def __init__(self, layer_collection):
|
||||
super(KroneckerProductFB, self).__init__(layer_collection)
|
||||
|
||||
Sets the instance members _input_damping and _output_damping. Requires the
|
||||
instance members _input_factor and _output_factor.
|
||||
def _setup_damping(self, damping, normalization=None):
|
||||
"""Makes functions that compute the damping values for both factors."""
|
||||
def compute_damping():
|
||||
if normalization is not None:
|
||||
maybe_normalized_damping = normalize_damping(damping, normalization)
|
||||
else:
|
||||
maybe_normalized_damping = damping
|
||||
|
||||
Args:
|
||||
damping: The base damping factor (float or Tensor) for the damped inverse.
|
||||
"""
|
||||
self._input_damping, self._output_damping = compute_pi_adjusted_damping(
|
||||
self._input_factor.get_cov(),
|
||||
self._output_factor.get_cov(),
|
||||
damping**0.5)
|
||||
return compute_pi_adjusted_damping(self._input_factor.get_cov(),
|
||||
self._output_factor.get_cov(),
|
||||
maybe_normalized_damping**0.5)
|
||||
|
||||
self._input_factor.register_damped_inverse(self._input_damping)
|
||||
self._output_factor.register_damped_inverse(self._output_damping)
|
||||
if normalization is not None:
|
||||
damping_id = ("compute_pi_adjusted_damping",
|
||||
"cov", self._input_factor.name,
|
||||
"cov", self._output_factor.name,
|
||||
"normalize_damping", damping, normalization, "power", 0.5)
|
||||
else:
|
||||
damping_id = ("compute_pi_adjusted_damping",
|
||||
"cov", self._input_factor.name,
|
||||
"cov", self._output_factor.name,
|
||||
damping, "power", 0.5)
|
||||
|
||||
self._input_damping_func = _package_func(lambda: compute_damping()[0],
|
||||
damping_id + ("ref", 0))
|
||||
self._output_damping_func = _package_func(lambda: compute_damping()[1],
|
||||
damping_id + ("ref", 1))
|
||||
|
||||
def register_matpower(self, exp):
|
||||
self._input_factor.register_matpower(exp, self._input_damping_func)
|
||||
self._output_factor.register_matpower(exp, self._output_damping_func)
|
||||
|
||||
@property
|
||||
def _renorm_coeff(self):
|
||||
@ -552,28 +651,15 @@ class KroneckerProductFB(FisherBlock):
|
||||
"""
|
||||
return 1.0
|
||||
|
||||
def multiply_inverse(self, vector):
|
||||
def multiply_matpower(self, vector, exp):
|
||||
reshaped_vector = utils.layer_params_to_mat2d(vector)
|
||||
reshaped_out = self._output_factor.right_multiply_inverse(
|
||||
reshaped_vector,
|
||||
self._output_damping)
|
||||
reshaped_out = self._input_factor.left_multiply_inverse(
|
||||
reshaped_out, self._input_damping)
|
||||
if self._renorm_coeff != 1.0:
|
||||
reshaped_out /= math_ops.cast(
|
||||
self._renorm_coeff, dtype=reshaped_out.dtype)
|
||||
return utils.mat2d_to_layer_params(vector, reshaped_out)
|
||||
|
||||
def multiply(self, vector):
|
||||
reshaped_vector = utils.layer_params_to_mat2d(vector)
|
||||
reshaped_out = self._output_factor.right_multiply(
|
||||
reshaped_vector,
|
||||
self._output_damping)
|
||||
reshaped_out = self._input_factor.left_multiply(
|
||||
reshaped_out, self._input_damping)
|
||||
reshaped_out = self._output_factor.right_multiply_matpower(
|
||||
reshaped_vector, exp, self._output_damping_func)
|
||||
reshaped_out = self._input_factor.left_multiply_matpower(
|
||||
reshaped_out, exp, self._input_damping_func)
|
||||
if self._renorm_coeff != 1.0:
|
||||
reshaped_out *= math_ops.cast(
|
||||
self._renorm_coeff, dtype=reshaped_out.dtype)
|
||||
self._renorm_coeff**exp, dtype=reshaped_out.dtype)
|
||||
return utils.mat2d_to_layer_params(vector, reshaped_out)
|
||||
|
||||
def full_fisher_block(self):
|
||||
@ -590,7 +676,7 @@ class KroneckerProductFB(FisherBlock):
|
||||
right_factor)
|
||||
|
||||
|
||||
class EmbeddingKFACFB(KroneckerProductFB):
|
||||
class EmbeddingKFACFB(InputOutputMultiMinibatch, KroneckerProductFB):
|
||||
"""K-FAC FisherBlock for embedding layers.
|
||||
|
||||
This FisherBlock is similar to EmbeddingKFACFB, except that its
|
||||
@ -608,8 +694,6 @@ class EmbeddingKFACFB(KroneckerProductFB):
|
||||
Fisher information matrix to which this FisherBlock belongs.
|
||||
vocab_size: int. Size of vocabulary for this embedding layer.
|
||||
"""
|
||||
self._inputs = []
|
||||
self._outputs = []
|
||||
self._vocab_size = vocab_size
|
||||
|
||||
super(EmbeddingKFACFB, self).__init__(layer_collection)
|
||||
@ -624,41 +708,18 @@ class EmbeddingKFACFB(KroneckerProductFB):
|
||||
damping: 0-D Tensor or float. 'damping' * identity is approximately added
|
||||
to this FisherBlock's Fisher approximation.
|
||||
"""
|
||||
# TODO(b/68033310): Validate which of,
|
||||
# (1) summing on a single device (as below), or
|
||||
# (2) on each device in isolation and aggregating
|
||||
# is faster.
|
||||
inputs = _concat_along_batch_dim(self._inputs)
|
||||
grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
|
||||
inputs, grads_list = self._package_minibatches(grads_list)
|
||||
|
||||
self._input_factor = self._layer_collection.make_or_get_factor( #
|
||||
fisher_factors.EmbeddingInputKroneckerFactor, #
|
||||
((inputs,), self._vocab_size))
|
||||
(inputs, self._vocab_size))
|
||||
self._output_factor = self._layer_collection.make_or_get_factor( #
|
||||
fisher_factors.FullyConnectedKroneckerFactor, #
|
||||
(grads_list,))
|
||||
self._register_damped_input_and_output_inverses(damping)
|
||||
|
||||
def tensors_to_compute_grads(self):
|
||||
return self._outputs
|
||||
|
||||
def register_additional_minibatch(self, inputs, outputs):
|
||||
"""Registers an additional minibatch to the FisherBlock.
|
||||
|
||||
Args:
|
||||
inputs: Tensor of shape [batch_size, input_size]. Inputs to the
|
||||
matrix-multiply.
|
||||
outputs: Tensor of shape [batch_size, output_size]. Layer preactivations.
|
||||
"""
|
||||
self._inputs.append(inputs)
|
||||
self._outputs.append(outputs)
|
||||
|
||||
@property
|
||||
def num_registered_minibatches(self):
|
||||
return len(self._inputs)
|
||||
self._setup_damping(damping)
|
||||
|
||||
|
||||
class FullyConnectedKFACBasicFB(KroneckerProductFB):
|
||||
class FullyConnectedKFACBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
|
||||
"""K-FAC FisherBlock for fully-connected (dense) layers.
|
||||
|
||||
This uses the Kronecker-factorized approximation from the original
|
||||
@ -674,8 +735,6 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
|
||||
has_bias: Whether the component Kronecker factors have an additive bias.
|
||||
(Default: False)
|
||||
"""
|
||||
self._inputs = []
|
||||
self._outputs = []
|
||||
self._has_bias = has_bias
|
||||
|
||||
super(FullyConnectedKFACBasicFB, self).__init__(layer_collection)
|
||||
@ -690,12 +749,7 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
|
||||
damping: 0-D Tensor or float. 'damping' * identity is approximately added
|
||||
to this FisherBlock's Fisher approximation.
|
||||
"""
|
||||
# TODO(b/68033310): Validate which of,
|
||||
# (1) summing on a single device (as below), or
|
||||
# (2) on each device in isolation and aggregating
|
||||
# is faster.
|
||||
inputs = _concat_along_batch_dim(self._inputs)
|
||||
grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
|
||||
inputs, grads_list = self._package_minibatches(grads_list)
|
||||
|
||||
self._input_factor = self._layer_collection.make_or_get_factor( #
|
||||
fisher_factors.FullyConnectedKroneckerFactor, #
|
||||
@ -703,28 +757,10 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
|
||||
self._output_factor = self._layer_collection.make_or_get_factor( #
|
||||
fisher_factors.FullyConnectedKroneckerFactor, #
|
||||
(grads_list,))
|
||||
self._register_damped_input_and_output_inverses(damping)
|
||||
|
||||
def tensors_to_compute_grads(self):
|
||||
return self._outputs
|
||||
|
||||
def register_additional_minibatch(self, inputs, outputs):
|
||||
"""Registers an additional minibatch to the FisherBlock.
|
||||
|
||||
Args:
|
||||
inputs: Tensor of shape [batch_size, input_size]. Inputs to the
|
||||
matrix-multiply.
|
||||
outputs: Tensor of shape [batch_size, output_size]. Layer preactivations.
|
||||
"""
|
||||
self._inputs.append(inputs)
|
||||
self._outputs.append(outputs)
|
||||
|
||||
@property
|
||||
def num_registered_minibatches(self):
|
||||
return len(self._inputs)
|
||||
self._setup_damping(damping)
|
||||
|
||||
|
||||
class ConvKFCBasicFB(KroneckerProductFB):
|
||||
class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
|
||||
"""FisherBlock for 2D convolutional layers using the basic KFC approx.
|
||||
|
||||
Estimates the Fisher Information matrix's blog for a convolutional
|
||||
@ -761,8 +797,6 @@ class ConvKFCBasicFB(KroneckerProductFB):
|
||||
strides: The stride size in this layer (1-D Tensor of length 4).
|
||||
padding: The padding in this layer (1-D of Tensor length 4).
|
||||
"""
|
||||
self._inputs = []
|
||||
self._outputs = []
|
||||
self._strides = tuple(strides) if isinstance(strides, list) else strides
|
||||
self._padding = padding
|
||||
self._has_bias = isinstance(params, (tuple, list))
|
||||
@ -773,17 +807,12 @@ class ConvKFCBasicFB(KroneckerProductFB):
|
||||
super(ConvKFCBasicFB, self).__init__(layer_collection)
|
||||
|
||||
def instantiate_factors(self, grads_list, damping):
|
||||
# TODO(b/68033310): Validate which of,
|
||||
# (1) summing on a single device (as below), or
|
||||
# (2) on each device in isolation and aggregating
|
||||
# is faster.
|
||||
inputs = _concat_along_batch_dim(self._inputs)
|
||||
grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
|
||||
|
||||
# Infer number of locations upon which convolution is applied.
|
||||
self._num_locations = num_conv_locations(inputs.shape.as_list(),
|
||||
self._num_locations = num_conv_locations(self._inputs[0].shape.as_list(),
|
||||
self._strides)
|
||||
|
||||
inputs, grads_list = self._package_minibatches(grads_list)
|
||||
|
||||
self._input_factor = self._layer_collection.make_or_get_factor(
|
||||
fisher_factors.ConvInputKroneckerFactor,
|
||||
(inputs, self._filter_shape, self._strides, self._padding,
|
||||
@ -791,60 +820,12 @@ class ConvKFCBasicFB(KroneckerProductFB):
|
||||
self._output_factor = self._layer_collection.make_or_get_factor(
|
||||
fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
|
||||
|
||||
damping = normalize_damping(damping, self._num_locations)
|
||||
self._register_damped_input_and_output_inverses(damping)
|
||||
self._damping = damping
|
||||
self._setup_damping(damping, normalization=self._num_locations)
|
||||
|
||||
@property
|
||||
def _renorm_coeff(self):
|
||||
return self._num_locations
|
||||
|
||||
def tensors_to_compute_grads(self):
|
||||
return self._outputs
|
||||
|
||||
def register_additional_minibatch(self, inputs, outputs):
|
||||
"""Registers an additional minibatch to the FisherBlock.
|
||||
|
||||
Args:
|
||||
inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to
|
||||
the convolution.
|
||||
outputs: Tensor of shape [batch_size, height, width, output_size]. Layer
|
||||
preactivations.
|
||||
"""
|
||||
self._inputs.append(inputs)
|
||||
self._outputs.append(outputs)
|
||||
|
||||
@property
|
||||
def num_registered_minibatches(self):
|
||||
return len(self._inputs)
|
||||
|
||||
|
||||
def _concat_along_batch_dim(tensor_list):
|
||||
"""Concatenate tensors along batch (first) dimension.
|
||||
|
||||
Args:
|
||||
tensor_list: list of Tensors or list of tuples of Tensors.
|
||||
|
||||
Returns:
|
||||
Tensor or tuple of Tensors.
|
||||
|
||||
Raises:
|
||||
ValueError: If 'tensor_list' is empty.
|
||||
|
||||
"""
|
||||
if not tensor_list:
|
||||
raise ValueError(
|
||||
"Cannot concatenate Tensors if there are no Tensors to concatenate.")
|
||||
|
||||
if isinstance(tensor_list[0], (tuple, list)):
|
||||
# [(tensor1a, tensor1b),
|
||||
# (tensor2a, tensor2b), ...] --> (tensor_a, tensor_b)
|
||||
return tuple(
|
||||
array_ops.concat(tensors, axis=0) for tensors in zip(*tensor_list))
|
||||
else:
|
||||
# [tensor1, tensor2] --> tensor
|
||||
return array_ops.concat(tensor_list, axis=0)
|
||||
|
||||
|
||||
def num_conv_locations(input_shape, strides):
|
||||
"""Returns the number of spatial locations a 2D Conv kernel is applied to.
|
||||
@ -859,49 +840,35 @@ def num_conv_locations(input_shape, strides):
|
||||
return input_shape[1] * input_shape[2] // (strides[1] * strides[2])
|
||||
|
||||
|
||||
class FullyConnectedMultiIndepFB(KroneckerProductFB):
|
||||
class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
|
||||
"""FisherBlock for fully-connected layers that share parameters.
|
||||
"""
|
||||
|
||||
def __init__(self, layer_collection, inputs, outputs, has_bias=False):
|
||||
def __init__(self, layer_collection, has_bias=False):
|
||||
"""Creates a FullyConnectedMultiIndepFB block.
|
||||
|
||||
Args:
|
||||
layer_collection: LayerCollection instance.
|
||||
inputs: list or tuple of Tensors. Each Tensor has shape [batch_size,
|
||||
inputs_size].
|
||||
outputs: list or tuple of Tensors. Each Tensor has shape [batch_size,
|
||||
outputs_size].
|
||||
has_bias: bool. If True, estimates Fisher with respect to a bias
|
||||
parameter as well as the layer's parameters.
|
||||
"""
|
||||
|
||||
assert len(inputs) == len(outputs)
|
||||
# We need to make sure inputs and outputs are tuples and not lists so that
|
||||
# they get hashed by layer_collection.make_or_get_factor properly.
|
||||
self._inputs = tuple(inputs)
|
||||
self._outputs = tuple(outputs)
|
||||
self._has_bias = has_bias
|
||||
self._num_uses = len(inputs)
|
||||
|
||||
super(FullyConnectedMultiIndepFB, self).__init__(layer_collection)
|
||||
|
||||
@property
|
||||
def num_registered_minibatches(self):
|
||||
# TODO(b/69411207): Add support for registering additional minibatches.
|
||||
return 1
|
||||
|
||||
def instantiate_factors(self, grads_list, damping):
|
||||
|
||||
self._num_uses = len(self._inputs[0])
|
||||
inputs, grads_list = self._package_minibatches_multi(grads_list)
|
||||
|
||||
self._input_factor = self._layer_collection.make_or_get_factor(
|
||||
fisher_factors.FullyConnectedMultiKF,
|
||||
((self._inputs,), self._has_bias))
|
||||
((inputs,), self._has_bias))
|
||||
|
||||
self._output_factor = self._layer_collection.make_or_get_factor(
|
||||
fisher_factors.FullyConnectedMultiKF, (grads_list,))
|
||||
|
||||
damping = normalize_damping(damping, self._num_uses)
|
||||
self._register_damped_input_and_output_inverses(damping)
|
||||
self._setup_damping(damping, normalization=self._num_uses)
|
||||
|
||||
@property
|
||||
def _renorm_coeff(self):
|
||||
@ -910,9 +877,6 @@ class FullyConnectedMultiIndepFB(KroneckerProductFB):
|
||||
def tensors_to_compute_grads(self):
|
||||
return self._outputs
|
||||
|
||||
def num_inputs(self):
|
||||
return len(self._inputs)
|
||||
|
||||
|
||||
class SeriesFBApproximation(enum.IntEnum):
|
||||
"""See FullyConnectedSeriesFB.__init__ for description and usage."""
|
||||
@ -920,22 +884,20 @@ class SeriesFBApproximation(enum.IntEnum):
|
||||
option2 = 2
|
||||
|
||||
|
||||
class FullyConnectedSeriesFB(FisherBlock):
|
||||
class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock):
|
||||
"""FisherBlock for fully-connected layers that share parameters across time.
|
||||
|
||||
See the following preprint for details:
|
||||
https://openreview.net/pdf?id=HyMTkQZAb
|
||||
|
||||
See the end of the appendix of the paper for a pseudo-code of the
|
||||
algorithm being implemented by multiply_inverse here. Note that we are
|
||||
algorithm being implemented by multiply_matpower here. Note that we are
|
||||
using pre-computed versions of certain matrix-matrix products to speed
|
||||
things up. This is explicitly explained wherever it is done.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
layer_collection,
|
||||
inputs,
|
||||
outputs,
|
||||
has_bias=False,
|
||||
option=SeriesFBApproximation.option2):
|
||||
"""Constructs a new `FullyConnectedSeriesFB`.
|
||||
@ -943,10 +905,6 @@ class FullyConnectedSeriesFB(FisherBlock):
|
||||
Args:
|
||||
layer_collection: The collection of all layers in the K-FAC approximate
|
||||
Fisher information matrix to which this FisherBlock belongs.
|
||||
inputs: List of tensors of shape [batch_size, input_size].
|
||||
Inputs to the layer.
|
||||
outputs: List of tensors of shape [batch_size, input_size].
|
||||
Outputs of the layer (before activations).
|
||||
has_bias: Whether the layer includes a bias parameter.
|
||||
option: A `SeriesFBApproximation` specifying the simplifying assumption
|
||||
to be used in this block. `option1` approximates the cross-covariance
|
||||
@ -955,48 +913,61 @@ class FullyConnectedSeriesFB(FisherBlock):
|
||||
3.5 of the paper for more details.
|
||||
"""
|
||||
|
||||
assert len(inputs) == len(outputs)
|
||||
# We need to make sure inputs and outputs are tuples and not lists so that
|
||||
# they get hashed by layer_collection.make_or_get_factor properly.
|
||||
self._inputs = tuple(inputs)
|
||||
self._outputs = tuple(outputs)
|
||||
self._has_bias = has_bias
|
||||
self._num_timesteps = len(inputs)
|
||||
self._option = option
|
||||
|
||||
super(FullyConnectedSeriesFB, self).__init__(layer_collection)
|
||||
|
||||
@property
|
||||
def num_registered_minibatches(self):
|
||||
# TODO(b/69411207): Add support for registering additional minibatches.
|
||||
return 1
|
||||
|
||||
def instantiate_factors(self, grads_list, damping):
|
||||
|
||||
self._num_timesteps = len(self._inputs[0])
|
||||
inputs, grads_list = self._package_minibatches_multi(grads_list)
|
||||
|
||||
self._input_factor = self._layer_collection.make_or_get_factor(
|
||||
fisher_factors.FullyConnectedMultiKF, ((self._inputs,), self._has_bias))
|
||||
fisher_factors.FullyConnectedMultiKF, ((inputs,), self._has_bias))
|
||||
self._input_factor.register_cov_dt1()
|
||||
|
||||
self._output_factor = self._layer_collection.make_or_get_factor(
|
||||
fisher_factors.FullyConnectedMultiKF, (grads_list,))
|
||||
self._output_factor.register_cov_dt1()
|
||||
|
||||
damping = normalize_damping(damping, self._num_timesteps)
|
||||
self._damping_input, self._damping_output = compute_pi_adjusted_damping(
|
||||
self._input_factor.get_cov(),
|
||||
self._output_factor.get_cov(),
|
||||
damping**0.5)
|
||||
def compute_damping():
|
||||
normalized_damping = normalize_damping(damping, self._num_timesteps)
|
||||
return compute_pi_adjusted_damping(self._input_factor.get_cov(),
|
||||
self._output_factor.get_cov(),
|
||||
normalized_damping**0.5)
|
||||
|
||||
damping_id = ("compute_pi_adjusted_damping",
|
||||
"cov", self._input_factor.name,
|
||||
"cov", self._output_factor.name,
|
||||
"normalize_damping",
|
||||
damping, self._num_timesteps, "power", 0.5)
|
||||
self._input_damping_func = _package_func(lambda: compute_damping()[0],
|
||||
damping_id + ("ref", 0))
|
||||
self._output_damping_func = _package_func(lambda: compute_damping()[1],
|
||||
damping_id + ("ref", 1))
|
||||
|
||||
def register_matpower(self, exp):
|
||||
if exp != -1:
|
||||
raise NotImplementedError("FullyConnectedSeriesFB only supports inverse"
|
||||
"multiplications.")
|
||||
|
||||
if self._option == SeriesFBApproximation.option1:
|
||||
self._input_factor.register_option1quants(self._damping_input)
|
||||
self._output_factor.register_option1quants(self._damping_output)
|
||||
self._input_factor.register_option1quants(self._input_damping_func)
|
||||
self._output_factor.register_option1quants(self._output_damping_func)
|
||||
elif self._option == SeriesFBApproximation.option2:
|
||||
self._input_factor.register_option2quants(self._damping_input)
|
||||
self._output_factor.register_option2quants(self._damping_output)
|
||||
self._input_factor.register_option2quants(self._input_damping_func)
|
||||
self._output_factor.register_option2quants(self._output_damping_func)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unrecognized FullyConnectedSeriesFB approximation: {}".format(
|
||||
self._option))
|
||||
|
||||
def multiply_inverse(self, vector):
|
||||
def multiply_matpower(self, vector, exp):
|
||||
if exp != -1:
|
||||
raise NotImplementedError("FullyConnectedSeriesFB only supports inverse"
|
||||
"multiplications.")
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
Z = utils.layer_params_to_mat2d(vector)
|
||||
@ -1008,8 +979,10 @@ class FullyConnectedSeriesFB(FisherBlock):
|
||||
if self._option == SeriesFBApproximation.option1:
|
||||
|
||||
# Note that L_A = A0^(-1/2) * U_A and L_G = G0^(-1/2) * U_G.
|
||||
L_A, psi_A = self._input_factor.get_option1quants(self._damping_input)
|
||||
L_G, psi_G = self._output_factor.get_option1quants(self._damping_output)
|
||||
L_A, psi_A = self._input_factor.get_option1quants(
|
||||
self._input_damping_func)
|
||||
L_G, psi_G = self._output_factor.get_option1quants(
|
||||
self._output_damping_func)
|
||||
|
||||
def gamma(x):
|
||||
# We are assuming that each case has the same number of time-steps.
|
||||
@ -1046,9 +1019,10 @@ class FullyConnectedSeriesFB(FisherBlock):
|
||||
|
||||
# Note that P_A = A_1^T * A_0^(-1) and P_G = G_1^T * G_0^(-1),
|
||||
# and K_A = A_0^(-1/2) * E_A and K_G = G_0^(-1/2) * E_G.
|
||||
P_A, K_A, mu_A = self._input_factor.get_option2quants(self._damping_input)
|
||||
P_A, K_A, mu_A = self._input_factor.get_option2quants(
|
||||
self._input_damping_func)
|
||||
P_G, K_G, mu_G = self._output_factor.get_option2quants(
|
||||
self._damping_output)
|
||||
self._output_damping_func)
|
||||
|
||||
# Our approach differs superficially from the pseudo-code in the paper
|
||||
# in order to reduce the total number of matrix-matrix multiplies.
|
||||
@ -1102,11 +1076,5 @@ class FullyConnectedSeriesFB(FisherBlock):
|
||||
|
||||
# pylint: enable=invalid-name
|
||||
|
||||
def multiply(self, vector):
|
||||
raise NotImplementedError
|
||||
|
||||
def tensors_to_compute_grads(self):
|
||||
return self._outputs
|
||||
|
||||
def num_inputs(self):
|
||||
return len(self._inputs)
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -130,6 +130,8 @@ class LayerCollection(object):
|
||||
fisher_factors: an OrderedDict mapping tuples to FisherFactor instances.
|
||||
losses: a list of LossFunction objects. The loss to be optimized is their
|
||||
sum.
|
||||
loss_colocation_ops: ops to colocate loss function evaluations with. These
|
||||
will typically be the inputs to the losses.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -148,14 +150,21 @@ class LayerCollection(object):
|
||||
self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME
|
||||
self._default_fully_connected_multi_approximation = (
|
||||
APPROX_KRONECKER_SERIES_2_NAME)
|
||||
self.loss_colocation_ops = {}
|
||||
self._vars_to_uses = defaultdict(lambda: 0)
|
||||
|
||||
with variable_scope.variable_scope(None, default_name=name) as scope:
|
||||
self._var_scope = scope.name
|
||||
|
||||
@property
|
||||
def losses(self):
|
||||
"""LossFunctions registered with this LayerCollection."""
|
||||
return list(self._loss_dict.values())
|
||||
"""Tuple of LossFunction objects registered with this LayerCollection."""
|
||||
return nest.flatten(self.towers_by_loss)
|
||||
|
||||
@property
|
||||
def towers_by_loss(self):
|
||||
"""Tuple across losses of LossFunction objects registered to each tower."""
|
||||
return tuple(tuple(lst) for lst in self._loss_dict.values())
|
||||
|
||||
@property
|
||||
def registered_variables(self):
|
||||
@ -290,23 +299,74 @@ class LayerCollection(object):
|
||||
self.fisher_blocks[layer_key] = fisher_block
|
||||
return fisher_block
|
||||
|
||||
def get_use_count_map(self):
|
||||
"""Returns a dict of variables to their number of registrations."""
|
||||
# TODO(b/70283403): Reimplement this in the old way, where each
|
||||
# registration function would be responsible for incrementing the count.
|
||||
# Also, this version has a bug: it won't do the right thing for generic
|
||||
# registration for parameters that are shared. i.e. it won't set the use
|
||||
# count to infinity.
|
||||
vars_to_uses = defaultdict(int)
|
||||
for key, block in six.iteritems(self.fisher_blocks):
|
||||
n = (
|
||||
block.num_inputs()*block.num_registered_minibatches if isinstance(
|
||||
block, (fb.FullyConnectedSeriesFB, fb.FullyConnectedMultiIndepFB))
|
||||
else block.num_registered_minibatches)
|
||||
key = utils.ensure_sequence(key)
|
||||
for k in key:
|
||||
vars_to_uses[k] += n
|
||||
return vars_to_uses
|
||||
def register_loss_function(self,
|
||||
loss,
|
||||
colocation_op,
|
||||
base_name,
|
||||
name=None,
|
||||
reuse=VARIABLE_SCOPE):
|
||||
"""Registers a LossFunction object.
|
||||
|
||||
Args:
|
||||
loss: The LossFunction object.
|
||||
colocation_op: The op to colocate the loss function's computations with.
|
||||
base_name: The name to derive a new unique name from is the name argument
|
||||
is None.
|
||||
name: (OPTIONAL) str or None. Unique name for this loss function. If None,
|
||||
a new name is generated. (Default: None)
|
||||
reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock.
|
||||
If False, create a new FisherBlock. If VARIABLE_SCOPE, use
|
||||
tf.get_variable_scope().reuse.
|
||||
|
||||
Raises:
|
||||
ValueError: If reuse == True and name == None.
|
||||
ValueError: If reuse == True and seed != None.
|
||||
KeyError: If reuse == True and no existing LossFunction with 'name' found.
|
||||
KeyError: If reuse == False and existing LossFunction with 'name' found.
|
||||
"""
|
||||
|
||||
name = name or self._graph.unique_name(base_name)
|
||||
|
||||
if reuse == VARIABLE_SCOPE:
|
||||
reuse = variable_scope.get_variable_scope().reuse
|
||||
|
||||
if reuse:
|
||||
if name is None:
|
||||
raise ValueError(
|
||||
"If reuse is enabled, loss function's name must be set.")
|
||||
|
||||
loss_list = self._loss_dict.get(name, None)
|
||||
|
||||
if loss_list is None:
|
||||
raise KeyError(
|
||||
"Unable to find loss function named {}. Register a new loss "
|
||||
"function with reuse=False.".format(name))
|
||||
else:
|
||||
if name in self._loss_dict:
|
||||
raise KeyError(
|
||||
"Loss function named {} already exists. Set reuse=True to append "
|
||||
"another minibatch/tower.".format(name))
|
||||
|
||||
loss_list = []
|
||||
self._loss_dict[name] = loss_list
|
||||
|
||||
loss_list.append(loss)
|
||||
self.loss_colocation_ops[loss] = colocation_op
|
||||
|
||||
def _get_use_count_map(self):
|
||||
"""Returns a dict mapping variables to their number of registrations."""
|
||||
return self._vars_to_uses
|
||||
|
||||
def _add_uses(self, params, uses):
|
||||
"""Register additional uses by params in the graph.
|
||||
|
||||
Args:
|
||||
params: Variable or tuple of Variables. Parameters for a layer.
|
||||
uses: int or float. Number of additional uses for these parameters.
|
||||
"""
|
||||
params = params if isinstance(params, (tuple, list)) else (params,)
|
||||
for var in params:
|
||||
self._vars_to_uses[var] += uses
|
||||
|
||||
def check_registration(self, variables):
|
||||
"""Checks that all variable uses have been registered properly.
|
||||
@ -324,7 +384,7 @@ class LayerCollection(object):
|
||||
# Note that overlapping parameters (i.e. those that share variables) will
|
||||
# be caught by layer_collection.LayerParametersDict during registration.
|
||||
|
||||
reg_use_map = self.get_use_count_map()
|
||||
reg_use_map = self._get_use_count_map()
|
||||
|
||||
error_messages = []
|
||||
|
||||
@ -414,12 +474,27 @@ class LayerCollection(object):
|
||||
inputs_to_losses = nest.flatten(tuple(loss.inputs for loss in self.losses))
|
||||
self._subgraph = utils.SubGraph(inputs_to_losses)
|
||||
|
||||
def eval_losses(self):
|
||||
"""Return evaluated losses (colocated with inputs to losses)."""
|
||||
evals = []
|
||||
for loss in self.losses:
|
||||
with ops.colocate_with(self.loss_colocation_ops[loss]):
|
||||
evals.append(loss.evaluate())
|
||||
return evals
|
||||
|
||||
def eval_losses_on_samples(self):
|
||||
"""Return losses evaluated on samples (colocated with inputs to losses)."""
|
||||
evals = []
|
||||
for loss in self.losses:
|
||||
with ops.colocate_with(self.loss_colocation_ops[loss]):
|
||||
evals.append(loss.evaluate_on_sample())
|
||||
return evals
|
||||
|
||||
def total_loss(self):
|
||||
return math_ops.add_n(tuple(loss.evaluate() for loss in self.losses))
|
||||
return math_ops.add_n(self.eval_losses())
|
||||
|
||||
def total_sampled_loss(self):
|
||||
return math_ops.add_n(
|
||||
tuple(loss.evaluate_on_sample() for loss in self.losses))
|
||||
return math_ops.add_n(self.eval_losses_on_samples())
|
||||
|
||||
def _get_linked_approx(self, params):
|
||||
"""If params were linked, return their specified approximation."""
|
||||
@ -469,6 +544,8 @@ class LayerCollection(object):
|
||||
params, fb.EmbeddingKFACFB(self, vocab_size), reuse=reuse)
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
|
||||
self._add_uses(params, 1)
|
||||
|
||||
def register_fully_connected(self,
|
||||
params,
|
||||
inputs,
|
||||
@ -505,9 +582,12 @@ class LayerCollection(object):
|
||||
block_type = _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES[approx]
|
||||
has_bias = isinstance(params, (tuple, list))
|
||||
|
||||
block = self.register_block(params, block_type(self, has_bias), reuse=reuse)
|
||||
block = self.register_block(params, block_type(self, has_bias=has_bias),
|
||||
reuse=reuse)
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
|
||||
self._add_uses(params, 1)
|
||||
|
||||
def register_conv2d(self,
|
||||
params,
|
||||
strides,
|
||||
@ -553,6 +633,8 @@ class LayerCollection(object):
|
||||
params, block_type(self, params, strides, padding), reuse=reuse)
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
|
||||
self._add_uses(params, 1)
|
||||
|
||||
def register_generic(self,
|
||||
params,
|
||||
batch_size,
|
||||
@ -586,8 +668,10 @@ class LayerCollection(object):
|
||||
block = self.register_block(params, block_type(self, params), reuse=reuse)
|
||||
block.register_additional_minibatch(batch_size)
|
||||
|
||||
self._add_uses(params, float("inf"))
|
||||
|
||||
def register_fully_connected_multi(self, params, inputs, outputs,
|
||||
approx=None):
|
||||
approx=None, reuse=VARIABLE_SCOPE):
|
||||
"""Register fully connected layers with shared parameters.
|
||||
|
||||
This can handle general fully-connected layers with shared parameters, but
|
||||
@ -604,6 +688,9 @@ class LayerCollection(object):
|
||||
[batch_size, output_size]. Outputs produced by layer. In the case of
|
||||
RNNs, one Tensor per time step.
|
||||
approx: str. One of "kron_indep", "kron_series_1", or "kron_series_2".
|
||||
reuse: bool or str. If True, reuse an existing FisherBlock. If False,
|
||||
create a new FisherBlock. If "VARIABLE_SCOPE", use
|
||||
tf.get_variable_scope().reuse.
|
||||
|
||||
Raises:
|
||||
ValueError: For improper value to 'approx'.
|
||||
@ -621,11 +708,14 @@ class LayerCollection(object):
|
||||
raise ValueError("Bad value {} for approx.".format(approx))
|
||||
block_type = _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES[approx]
|
||||
|
||||
# For now we don't support multiple minibatches for this type of layer, so
|
||||
# we set reuse=False
|
||||
self.register_block(params,
|
||||
block_type(self, inputs, outputs, has_bias=has_bias),
|
||||
reuse=False)
|
||||
block = self.register_block(params, block_type(self, has_bias=has_bias),
|
||||
reuse=reuse)
|
||||
block.register_additional_minibatch(inputs, outputs)
|
||||
self._add_uses(params, len(inputs))
|
||||
|
||||
# TODO(b/74108452): change the loss registration functions names to refer
|
||||
# to "loss functions" instead of distributions. Following naming convention
|
||||
# of the loss function classes themselves.
|
||||
|
||||
def register_categorical_predictive_distribution(self,
|
||||
logits,
|
||||
@ -648,50 +738,20 @@ class LayerCollection(object):
|
||||
reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock.
|
||||
If False, create a new FisherBlock. If VARIABLE_SCOPE, use
|
||||
tf.get_variable_scope().reuse.
|
||||
|
||||
Raises:
|
||||
ValueError: If reuse == True and name == None.
|
||||
ValueError: If reuse == True and seed != None.
|
||||
KeyError: If reuse == True and no existing LossFunction with 'name' found.
|
||||
KeyError: If reuse == False and existing LossFunction with 'name' found.
|
||||
"""
|
||||
name = name or self._graph.unique_name(
|
||||
"register_categorical_predictive_distribution")
|
||||
|
||||
if reuse == VARIABLE_SCOPE:
|
||||
reuse = variable_scope.get_variable_scope().reuse
|
||||
|
||||
if reuse:
|
||||
if name is None:
|
||||
raise ValueError(
|
||||
"If reuse is enabled, loss function's name must be set.")
|
||||
if seed is not None:
|
||||
raise ValueError(
|
||||
"Seed can only be specified at LossFunction instantiation.")
|
||||
|
||||
loss = self._loss_dict.get(name, None)
|
||||
|
||||
if loss is None:
|
||||
raise KeyError(
|
||||
"Unable to find loss function named {}. Create a new LossFunction "
|
||||
"with reuse=False.".format(name))
|
||||
|
||||
loss.register_additional_minibatch(logits, targets=targets)
|
||||
else:
|
||||
if name in self._loss_dict:
|
||||
raise KeyError(
|
||||
"Loss function named {} already exists. Set reuse=True to append "
|
||||
"another minibatch.".format(name))
|
||||
loss = lf.CategoricalLogitsNegativeLogProbLoss(
|
||||
logits, targets=targets, seed=seed)
|
||||
self._loss_dict[name] = loss
|
||||
loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets,
|
||||
seed=seed)
|
||||
self.register_loss_function(loss, logits,
|
||||
"categorical_predictive_distribution",
|
||||
name=name, reuse=reuse)
|
||||
|
||||
def register_normal_predictive_distribution(self,
|
||||
mean,
|
||||
var=0.5,
|
||||
seed=None,
|
||||
targets=None,
|
||||
name=None):
|
||||
name=None,
|
||||
reuse=VARIABLE_SCOPE):
|
||||
"""Registers a normal predictive distribution.
|
||||
|
||||
Args:
|
||||
@ -708,21 +768,22 @@ class LayerCollection(object):
|
||||
(Default: None)
|
||||
name: (OPTIONAL) str or None. Unique name for this loss function. If None,
|
||||
a new name is generated. (Default: None)
|
||||
reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock.
|
||||
If False, create a new FisherBlock. If VARIABLE_SCOPE, use
|
||||
tf.get_variable_scope().reuse.
|
||||
"""
|
||||
name = name or self._graph.unique_name(
|
||||
"register_normal_predictive_distribution")
|
||||
if name in self._loss_dict:
|
||||
raise NotImplementedError(
|
||||
"Adding logits to an existing LossFunction not yet supported.")
|
||||
loss = lf.NormalMeanNegativeLogProbLoss(
|
||||
mean, var, targets=targets, seed=seed)
|
||||
self._loss_dict[name] = loss
|
||||
loss = lf.NormalMeanNegativeLogProbLoss(mean, var, targets=targets,
|
||||
seed=seed)
|
||||
self.register_loss_function(loss, mean,
|
||||
"normal_predictive_distribution",
|
||||
name=name, reuse=reuse)
|
||||
|
||||
def register_multi_bernoulli_predictive_distribution(self,
|
||||
logits,
|
||||
seed=None,
|
||||
targets=None,
|
||||
name=None):
|
||||
name=None,
|
||||
reuse=VARIABLE_SCOPE):
|
||||
"""Registers a multi-Bernoulli predictive distribution.
|
||||
|
||||
Args:
|
||||
@ -735,15 +796,15 @@ class LayerCollection(object):
|
||||
(Default: None)
|
||||
name: (OPTIONAL) str or None. Unique name for this loss function. If None,
|
||||
a new name is generated. (Default: None)
|
||||
reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock.
|
||||
If False, create a new FisherBlock. If VARIABLE_SCOPE, use
|
||||
tf.get_variable_scope().reuse.
|
||||
"""
|
||||
name = name or self._graph.unique_name(
|
||||
"register_multi_bernoulli_predictive_distribution")
|
||||
if name in self._loss_dict:
|
||||
raise NotImplementedError(
|
||||
"Adding logits to an existing LossFunction not yet supported.")
|
||||
loss = lf.MultiBernoulliNegativeLogProbLoss(
|
||||
logits, targets=targets, seed=seed)
|
||||
self._loss_dict[name] = loss
|
||||
loss = lf.MultiBernoulliNegativeLogProbLoss(logits, targets=targets,
|
||||
seed=seed)
|
||||
self.register_loss_function(loss, logits,
|
||||
"multi_bernoulli_predictive_distribution",
|
||||
name=name, reuse=reuse)
|
||||
|
||||
def make_or_get_factor(self, cls, args):
|
||||
"""Insert 'cls(args)' into 'self.fisher_factors' if not already present.
|
||||
|
@ -57,30 +57,6 @@ class LossFunction(object):
|
||||
"""The inputs to the loss function (excluding the targets)."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def input_minibatches(self):
|
||||
"""A `list` of inputs to the loss function, separated by minibatch.
|
||||
|
||||
Typically there will be one minibatch per tower in a multi-tower setup.
|
||||
Returns a list consisting of `self.inputs` by default; `LossFunction`s
|
||||
supporting registering multiple minibatches should override this method.
|
||||
|
||||
Returns:
|
||||
A `list` of `Tensor`s representing
|
||||
"""
|
||||
return [self.inputs]
|
||||
|
||||
@property
|
||||
def num_registered_minibatches(self):
|
||||
"""Number of minibatches registered for this LossFunction.
|
||||
|
||||
Typically equal to the number of towers in a multi-tower setup.
|
||||
|
||||
Returns:
|
||||
An `int` representing the number of registered minibatches.
|
||||
"""
|
||||
return len(self.input_minibatches)
|
||||
|
||||
def evaluate(self):
|
||||
"""Evaluate the loss function on the targets."""
|
||||
if self.targets is not None:
|
||||
@ -474,7 +450,6 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
|
||||
assert len(variance.shape) == 2, "Expect 2D variance tensor."
|
||||
self._mean = mean
|
||||
self._variance = variance
|
||||
self._scale = math_ops.sqrt(variance)
|
||||
self._targets = targets
|
||||
super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed)
|
||||
|
||||
@ -484,7 +459,7 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
|
||||
|
||||
@property
|
||||
def dist(self):
|
||||
return normal.Normal(loc=self._mean, scale=self._scale)
|
||||
return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._variance))
|
||||
|
||||
@property
|
||||
def params(self):
|
||||
@ -502,7 +477,7 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
|
||||
|
||||
@property
|
||||
def _fisher_mean_factor(self):
|
||||
return 1. / self._scale
|
||||
return 1. / math_ops.sqrt(self._variance)
|
||||
|
||||
@property
|
||||
def _fisher_var(self):
|
||||
@ -611,36 +586,13 @@ class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss,
|
||||
index in [0, output_size).
|
||||
seed: int or None. Default random seed when sampling.
|
||||
"""
|
||||
self._logits_components = []
|
||||
self._targets_components = []
|
||||
self.register_additional_minibatch(logits, targets=targets)
|
||||
self._logits = logits
|
||||
self._targets = targets
|
||||
super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed)
|
||||
|
||||
def register_additional_minibatch(self, logits, targets=None):
|
||||
"""Register an additiona minibatch's worth of parameters.
|
||||
|
||||
Args:
|
||||
logits: Tensor of shape [batch_size, output_size]. Parameters for
|
||||
underlying distribution.
|
||||
targets: None or Tensor of shape [batch_size, output_size]. Each row must
|
||||
be a one-hot vector.
|
||||
"""
|
||||
self._logits_components.append(logits)
|
||||
self._targets_components.append(targets)
|
||||
|
||||
@property
|
||||
def _logits(self):
|
||||
return array_ops.concat(self._logits_components, axis=0)
|
||||
|
||||
@property
|
||||
def input_minibatches(self):
|
||||
return self._logits_components
|
||||
|
||||
@property
|
||||
def targets(self):
|
||||
if all(target is None for target in self._targets_components):
|
||||
return None
|
||||
return array_ops.concat(self._targets_components, axis=0)
|
||||
return self._targets
|
||||
|
||||
@property
|
||||
def dist(self):
|
||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import warnings
|
||||
|
||||
# pylint disable=long-line
|
||||
from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp
|
||||
from tensorflow.contrib.kfac.python.ops import estimator as est
|
||||
@ -50,6 +52,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
|
||||
name="KFAC",
|
||||
estimation_mode="gradients",
|
||||
colocate_gradients_with_ops=True,
|
||||
batch_size=None,
|
||||
cov_devices=None,
|
||||
inv_devices=None):
|
||||
"""Initializes the KFAC optimizer with the given settings.
|
||||
@ -91,12 +94,16 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
|
||||
colocate_gradients_with_ops: Whether we should request gradients we
|
||||
compute in the estimator be colocated with their respective ops.
|
||||
(Default: True)
|
||||
batch_size: The size of the mini-batch. Only needed when momentum_type
|
||||
== 'qmodel' or when automatic adjustment is used. (Default: None)
|
||||
cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
|
||||
computations will be placed on these devices in a round-robin fashion.
|
||||
Can be None, which means that no devices are specified.
|
||||
Can be None, which means that no devices are specified. Only used
|
||||
with (soon-to-be-depcrecated "convenience" properties).
|
||||
inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
|
||||
computations will be placed on these devices in a round-robin fashion.
|
||||
Can be None, which means that no devices are specified.
|
||||
Can be None, which means that no devices are specified. Only used
|
||||
with (soon-to-be-depcrecated "convenience" properties).
|
||||
|
||||
Raises:
|
||||
ValueError: If the momentum type is unsupported.
|
||||
@ -110,6 +117,15 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
|
||||
if variables is None:
|
||||
variables = tf_variables.trainable_variables()
|
||||
|
||||
# Parameters to be passed to the Fisher estimator:
|
||||
self._variables = variables
|
||||
self._cov_ema_decay = cov_ema_decay
|
||||
self._layers = layer_collection
|
||||
self._estimation_mode = estimation_mode
|
||||
self._colocate_gradients_with_ops = colocate_gradients_with_ops
|
||||
self._cov_devices = cov_devices
|
||||
self._inv_devices = inv_devices
|
||||
|
||||
# The below paramaters are required only if damping needs to be adapated.
|
||||
# These parameters can be set by calling
|
||||
# set_damping_adaptation_params() explicitly.
|
||||
@ -130,17 +146,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
|
||||
self._q_model_change = None
|
||||
self._update_damping_op = None
|
||||
|
||||
self._layers = layer_collection
|
||||
self._fisher_est = est.FisherEstimator(
|
||||
lambda: self.damping,
|
||||
variables,
|
||||
cov_ema_decay,
|
||||
layer_collection,
|
||||
estimation_mode=estimation_mode,
|
||||
colocate_gradients_with_ops=colocate_gradients_with_ops,
|
||||
cov_devices=cov_devices,
|
||||
inv_devices=inv_devices)
|
||||
|
||||
momentum_type = momentum_type.lower()
|
||||
legal_momentum_types = ["regular", "adam", "qmodel"]
|
||||
|
||||
@ -154,14 +159,21 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
|
||||
raise ValueError("Momentum must be unspecified if using a momentum_type "
|
||||
"other than 'regular' or 'adam'.")
|
||||
|
||||
# Extra parameters of the optimizer
|
||||
self._momentum = momentum
|
||||
self._momentum_type = momentum_type
|
||||
self._norm_constraint = norm_constraint
|
||||
self._batch_size = batch_size
|
||||
|
||||
# this is a bit of a hack
|
||||
# TODO(duckworthd): Handle this in a better way (e.g. pass it in?)
|
||||
self._batch_size = array_ops.shape(layer_collection.losses[0].inputs)[0]
|
||||
self._losses = layer_collection.losses
|
||||
with variable_scope.variable_scope(name):
|
||||
self._fisher_est = est.FisherEstimator(
|
||||
self._variables,
|
||||
self._cov_ema_decay,
|
||||
self.damping,
|
||||
self._layers,
|
||||
exps=(-1,),
|
||||
estimation_mode=self._estimation_mode,
|
||||
colocate_gradients_with_ops=self._colocate_gradients_with_ops)
|
||||
|
||||
super(KfacOptimizer, self).__init__(learning_rate, name=name)
|
||||
|
||||
@ -178,6 +190,10 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
|
||||
style rule described in Section 6.5 of "Optimizing Neural Networks with
|
||||
Kronecker-factored Approximate Curvature".
|
||||
|
||||
Note that this function creates Tensorflow variables which store a few
|
||||
scalars and are accessed by the ops which update the damping (as part
|
||||
of the training op returned by the minimize() method).
|
||||
|
||||
Args:
|
||||
is_chief: `Boolean`, `True` if the worker is chief.
|
||||
prev_train_batch: Training data used to minimize loss in the previous
|
||||
@ -199,6 +215,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
|
||||
"""
|
||||
if self._adapt_damping:
|
||||
raise ValueError("Damping adaptation parameters already set.")
|
||||
|
||||
with variable_scope.variable_scope(self.get_name()):
|
||||
self._adapt_damping = True
|
||||
self._is_chief = is_chief
|
||||
@ -221,31 +238,37 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
|
||||
|
||||
@property
|
||||
def cov_update_thunks(self):
|
||||
return self._fisher_est.cov_update_thunks
|
||||
self._maybe_make_and_save_everything()
|
||||
return self._cov_update_thunks
|
||||
|
||||
@property
|
||||
def cov_update_ops(self):
|
||||
return self._fisher_est.cov_update_ops
|
||||
self._maybe_make_and_save_everything()
|
||||
return self._cov_update_ops
|
||||
|
||||
@property
|
||||
def cov_update_op(self):
|
||||
return self._fisher_est.cov_update_op
|
||||
self._maybe_make_and_save_everything()
|
||||
return self._cov_update_op
|
||||
|
||||
@property
|
||||
def inv_update_thunks(self):
|
||||
return self._fisher_est.inv_update_thunks
|
||||
self._maybe_make_and_save_everything()
|
||||
return self._inv_update_thunks
|
||||
|
||||
@property
|
||||
def inv_update_ops(self):
|
||||
return self._fisher_est.inv_update_ops
|
||||
self._maybe_make_and_save_everything()
|
||||
return self._inv_update_ops
|
||||
|
||||
@property
|
||||
def inv_update_op(self):
|
||||
return self._fisher_est.inv_update_op
|
||||
self._maybe_make_and_save_everything()
|
||||
return self._inv_update_op
|
||||
|
||||
@property
|
||||
def variables(self):
|
||||
return self._fisher_est.variables
|
||||
return self._variables
|
||||
|
||||
@property
|
||||
def damping(self):
|
||||
@ -258,25 +281,162 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
|
||||
def damping_adaptation_interval(self):
|
||||
return self._damping_adaptation_interval
|
||||
|
||||
def _maybe_make_and_save_everything(self):
|
||||
if not self._fisher_est.made_vars():
|
||||
warnings.warn("These convenience properties will be depcrecated soon. "
|
||||
"Please use explicit op/thunk creation methods instead "
|
||||
"(e.g. make_ops_and_vars_round_robin, etc).",
|
||||
DeprecationWarning)
|
||||
(self._cov_update_ops, self._cov_update_op, self._inv_update_ops,
|
||||
self._inv_update_op, self._cov_update_thunks,
|
||||
self._inv_update_thunks) = self.make_ops_and_vars_round_robin(
|
||||
cov_devices=self._cov_devices,
|
||||
inv_devices=self._inv_devices)
|
||||
|
||||
def make_ops_and_vars(self):
|
||||
"""Make ops and vars with no specific device placement.
|
||||
|
||||
See make_ops_and_vars_round_robin for details.
|
||||
|
||||
Returns:
|
||||
cov_update_ops: List of ops that compute the cov updates. Corresponds
|
||||
one-to-one with the list of factors given by the "factors" property.
|
||||
cov_update_op: cov_update_ops grouped into a single op.
|
||||
inv_update_ops: List of ops that compute the inv updates. Corresponds
|
||||
one-to-one with the list of factors given by the "factors" property.
|
||||
cov_update_op: cov_update_ops grouped into a single op.
|
||||
inv_update_op: inv_update_ops grouped into a single op.
|
||||
"""
|
||||
with variable_scope.variable_scope(self.get_name()):
|
||||
return self._fisher_est.make_ops_and_vars()
|
||||
|
||||
def make_ops_and_vars_round_robin(self, cov_devices=None, inv_devices=None):
|
||||
"""Make ops and vars with a round-robin device placement strategy.
|
||||
|
||||
For each factor, all of that factor's cov variables and their associated
|
||||
update ops will be placed on a particular device. A new device is chosen
|
||||
for each factor by cycling through list of devices in the cov_devices
|
||||
argument. If cov_devices is None then no explicit device placement occurs.
|
||||
|
||||
An analogous strategy is followed for inverse update ops, with the list of
|
||||
devices being given by the inv_devices argument.
|
||||
|
||||
Inverse variables on the other hand are not placed on any specific device
|
||||
(they will just use the current the device placement context, whatever
|
||||
that happens to be). The idea is that the inverse variable belong where
|
||||
they will be accessed most often, which is the device that actually applies
|
||||
the preconditioner to the gradient. The user will be responsible for setting
|
||||
the device context for this.
|
||||
|
||||
Args:
|
||||
cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
|
||||
computations will be placed on these devices in a round-robin fashion.
|
||||
Can be None, which means that no devices are specified.
|
||||
inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
|
||||
computations will be placed on these devices in a round-robin fashion.
|
||||
Can be None, which means that no devices are specified.
|
||||
|
||||
Returns:
|
||||
cov_update_ops: List of ops that compute the cov updates. Corresponds
|
||||
one-to-one with the list of factors given by the "factors" property.
|
||||
cov_update_op: cov_update_ops grouped into a single op.
|
||||
inv_update_ops: List of ops that compute the inv updates. Corresponds
|
||||
one-to-one with the list of factors given by the "factors" property.
|
||||
cov_update_op: cov_update_ops grouped into a single op.
|
||||
inv_update_op: inv_update_ops grouped into a single op.
|
||||
cov_update_thunks: Thunks that make the ops in cov_update_ops.
|
||||
inv_update_thunks: Thunks that make the ops in inv_update_ops.
|
||||
"""
|
||||
with variable_scope.variable_scope(self.get_name()):
|
||||
return self._fisher_est.make_ops_and_vars_round_robin(
|
||||
cov_devices=cov_devices, inv_devices=inv_devices)
|
||||
|
||||
def make_vars_and_create_op_thunks_round_robin(self,
|
||||
cov_devices=None,
|
||||
inv_devices=None):
|
||||
"""Make vars and create op thunks w/ a round-robin device placement strat.
|
||||
|
||||
For each factor, all of that factor's cov variables and their associated
|
||||
update ops will be placed on a particular device. A new device is chosen
|
||||
for each factor by cycling through list of devices in the cov_devices
|
||||
argument. If cov_devices is None then no explicit device placement occurs.
|
||||
|
||||
An analogous strategy is followed for inverse update ops, with the list of
|
||||
devices being given by the inv_devices argument.
|
||||
|
||||
Inverse variables on the other hand are not placed on any specific device
|
||||
(they will just use the current the device placement context, whatever
|
||||
that happens to be). The idea is that the inverse variable belong where
|
||||
they will be accessed most often, which is the device that actually applies
|
||||
the preconditioner to the gradient. The user will be responsible for setting
|
||||
the device context for this.
|
||||
|
||||
Args:
|
||||
cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
|
||||
computations will be placed on these devices in a round-robin fashion.
|
||||
Can be None, which means that no devices are specified.
|
||||
inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
|
||||
computations will be placed on these devices in a round-robin fashion.
|
||||
Can be None, which means that no devices are specified.
|
||||
Returns:
|
||||
cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
|
||||
the list of factors given by the "factors" property.
|
||||
inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
|
||||
the list of factors given by the "factors" property.
|
||||
"""
|
||||
scope = self.get_name() + "/" + self._fisher_est.name
|
||||
return self._fisher_est.make_vars_and_create_op_thunks_round_robin(
|
||||
scope=scope, cov_devices=cov_devices, inv_devices=inv_devices)
|
||||
|
||||
def ops_and_vars_thunks(self):
|
||||
"""Create thunks that make the ops and vars on demand.
|
||||
|
||||
This function returns 4 lists of thunks: cov_variable_thunks,
|
||||
cov_update_thunks, inv_variable_thunks, and inv_update_thunks.
|
||||
|
||||
The length of each list is the number of factors and the i-th element of
|
||||
each list corresponds to the i-th factor (given by the "factors" property).
|
||||
|
||||
Note that the execution of these thunks must happen in a certain
|
||||
partial order. The i-th element of cov_variable_thunks must execute
|
||||
before the i-th element of cov_update_thunks (and also the i-th element
|
||||
of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks
|
||||
must execute before the i-th element of inv_update_thunks.
|
||||
|
||||
TL;DR (oversimplified): Execute the thunks according to the order that
|
||||
they are returned.
|
||||
|
||||
Returns:
|
||||
cov_variable_thunks: A list of thunks that make the cov variables.
|
||||
cov_update_thunks: A list of thunks that make the cov update ops.
|
||||
inv_variable_thunks: A list of thunks that make the inv variables.
|
||||
inv_update_thunks: A list of thunks that make the inv update ops.
|
||||
"""
|
||||
scope = self.get_name() + "/" + self._fisher_est.name
|
||||
return self._fisher_est.ops_and_vars_thunks(scope=scope)
|
||||
|
||||
def minimize(self, *args, **kwargs):
|
||||
kwargs["var_list"] = kwargs.get("var_list") or self.variables
|
||||
if set(kwargs["var_list"]) != set(self.variables):
|
||||
raise ValueError("var_list doesn't match with set of Fisher-estimating "
|
||||
"variables.")
|
||||
if self._adapt_damping and self._is_chief:
|
||||
global_step = kwargs.get("global_step", None)
|
||||
if not global_step:
|
||||
raise KeyError("global_step needs to be passed to optimizer.minimize "
|
||||
"if damping parameter is adapted.")
|
||||
update_damping_op = self._update_damping(self._prev_train_batch,
|
||||
global_step)
|
||||
with ops.control_dependencies([update_damping_op]):
|
||||
loss = args[0]
|
||||
loss_assign_op = state_ops.assign(self._prev_loss, loss)
|
||||
train_op = super(KfacOptimizer, self).minimize(*args, **kwargs)
|
||||
return control_flow_ops.group(loss_assign_op, train_op)
|
||||
else:
|
||||
return super(KfacOptimizer, self).minimize(*args, **kwargs)
|
||||
# Should this variable scope encompass everything below? Or will the super-
|
||||
# class make another copy of the same name scope?
|
||||
with variable_scope.variable_scope(self.get_name()):
|
||||
kwargs["var_list"] = kwargs.get("var_list") or self.variables
|
||||
if set(kwargs["var_list"]) != set(self.variables):
|
||||
raise ValueError("var_list doesn't match with set of Fisher-estimating "
|
||||
"variables.")
|
||||
if self._adapt_damping and self._is_chief:
|
||||
global_step = kwargs.get("global_step", None)
|
||||
if not global_step:
|
||||
raise KeyError("global_step needs to be passed to optimizer.minimize "
|
||||
"if damping parameter is adapted.")
|
||||
update_damping_op = self._update_damping(self._prev_train_batch,
|
||||
global_step)
|
||||
with ops.control_dependencies([update_damping_op]):
|
||||
loss = args[0]
|
||||
loss_assign_op = state_ops.assign(self._prev_loss, loss)
|
||||
train_op = super(KfacOptimizer, self).minimize(*args, **kwargs)
|
||||
return control_flow_ops.group(loss_assign_op, train_op)
|
||||
else:
|
||||
return super(KfacOptimizer, self).minimize(*args, **kwargs)
|
||||
|
||||
def compute_gradients(self, *args, **kwargs):
|
||||
# args[1] could be our var_list
|
||||
@ -301,6 +461,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
|
||||
Returns:
|
||||
An `Operation` that applies the specified gradients.
|
||||
"""
|
||||
self._maybe_make_and_save_everything()
|
||||
|
||||
# In Python 3, grads_and_vars can be a zip() object which can only be
|
||||
# iterated over once. By converting it to a list, we ensure that it can be
|
||||
# iterated over more than once.
|
||||
@ -450,7 +612,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
|
||||
= qmodel(alpha*precon_grad + mu*prev_update) - L(theta).
|
||||
"""
|
||||
|
||||
cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._losses, variables)
|
||||
cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._layers.losses,
|
||||
variables)
|
||||
|
||||
# compute the matrix-vector products with the transposed Fisher factor
|
||||
fft_precon_grads = cmvpc.multiply_fisher_factor_transpose(precon_grads)
|
||||
|
@ -24,6 +24,7 @@ from tensorflow.contrib.tpu.python.ops import tpu_ops
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
@ -482,5 +483,76 @@ def matmul_diag_sparse(A_diag, B, name=None): # pylint: disable=invalid-name
|
||||
a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1))
|
||||
return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape)
|
||||
|
||||
|
||||
class PartitionedTensor(object):
|
||||
"""A Tensor partitioned across its 0-th dimension."""
|
||||
|
||||
def __init__(self, tensors):
|
||||
"""Initializes PartitionedTensor.
|
||||
|
||||
Args:
|
||||
tensors: List of Tensors. All Tensors must agree on shape (excepting
|
||||
batch dimension) and dtype.
|
||||
|
||||
Raises:
|
||||
ValueError: If 'tensors' has length zero.
|
||||
ValueError: if contents of 'tensors' don't agree on shape or dtype.
|
||||
"""
|
||||
if not tensors:
|
||||
raise ValueError("tensors must be a list of 1+ Tensors.")
|
||||
|
||||
dtype = tensors[0].dtype
|
||||
if not all(tensor.dtype == dtype for tensor in tensors):
|
||||
raise ValueError("all tensors must have dtype = %s." % dtype)
|
||||
|
||||
shape = tensors[0].shape[1:]
|
||||
if not all(tensor.shape[1:] == shape for tensor in tensors):
|
||||
raise ValueError("All tensors must have shape = %s (excluding batch "
|
||||
"dimension)." % shape)
|
||||
|
||||
self.tensors = tensors
|
||||
self._concats = {} # {device: Tensor}
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
feature_shape = self.tensors[0].shape[1:]
|
||||
batch_size = sum([tensor.shape[0] for tensor in self.tensors],
|
||||
tensor_shape.Dimension(0))
|
||||
return tensor_shape.TensorShape([batch_size]).concatenate(feature_shape)
|
||||
|
||||
def get_shape(self):
|
||||
return self.shape
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.tensors[0].dtype
|
||||
|
||||
def devices(self):
|
||||
return set(tensor.device for tensor in self.tensors)
|
||||
|
||||
def __str__(self):
|
||||
return "PartitionedTensor([%s, ...], dtype=%s, shape=%s)" % (
|
||||
self.tensors[0].name, self.dtype.name, tuple(self.shape.as_list()))
|
||||
|
||||
def __hash__(self):
|
||||
return hash(tuple(self.tensors))
|
||||
|
||||
def as_tensor(self, dtype=None, name=None, as_ref=False):
|
||||
with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors):
|
||||
assert not as_ref
|
||||
assert dtype in [None, self.dtype]
|
||||
result = array_ops.concat(self.tensors, axis=0)
|
||||
|
||||
# Cache 'result' if we haven't already cached a value for this device.
|
||||
if result.device not in self._concats:
|
||||
self._concats[result.device] = result
|
||||
return self._concats[result.device]
|
||||
|
||||
|
||||
ops.register_tensor_conversion_function(
|
||||
PartitionedTensor,
|
||||
lambda val, dtype, name, as_ref: val.as_tensor(dtype, name, as_ref))
|
||||
|
||||
|
||||
# TODO(b/69623235): Add a function for finding tensors that share gradients
|
||||
# to eliminate redundant fisher factor computations.
|
||||
|
Loading…
x
Reference in New Issue
Block a user