- Adding support for Cholesky (inverse) factor multiplications.
- Refactored FisherFactor to use LinearOperator classes that know how to multiply themselves, compute their own trace, etc. This addresses the feature request: b/73356352 - Fixed some problems with FisherEstimator construction - More careful casting of damping constants before they are used PiperOrigin-RevId: 194379298
This commit is contained in:
parent
8148895adc
commit
481f229881
@ -58,6 +58,7 @@ py_test(
|
||||
deps = [
|
||||
"//tensorflow/contrib/kfac/python/ops:fisher_blocks",
|
||||
"//tensorflow/contrib/kfac/python/ops:layer_collection",
|
||||
"//tensorflow/contrib/kfac/python/ops:linear_operator",
|
||||
"//tensorflow/contrib/kfac/python/ops:utils",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
|
@ -22,6 +22,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
|
||||
from tensorflow.contrib.kfac.python.ops import layer_collection as lc
|
||||
from tensorflow.contrib.kfac.python.ops import linear_operator as lo
|
||||
from tensorflow.contrib.kfac.python.ops import utils
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
@ -46,8 +47,9 @@ class UtilsTest(test.TestCase):
|
||||
def testComputePiTracenorm(self):
|
||||
with ops.Graph().as_default(), self.test_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
left_factor = array_ops.diag([1., 2., 0., 1.])
|
||||
right_factor = array_ops.ones([2., 2.])
|
||||
diag = ops.convert_to_tensor([1., 2., 0., 1.])
|
||||
left_factor = lo.LinearOperatorDiag(diag)
|
||||
right_factor = lo.LinearOperatorFullMatrix(array_ops.ones([2, 2]))
|
||||
|
||||
# pi is the sqrt of the left trace norm divided by the right trace norm
|
||||
pi = fb.compute_pi_tracenorm(left_factor, right_factor)
|
||||
@ -245,7 +247,6 @@ class NaiveDiagonalFBTest(test.TestCase):
|
||||
|
||||
full = sess.run(block.full_fisher_block())
|
||||
explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat)
|
||||
|
||||
self.assertAllClose(output_flat, explicit)
|
||||
|
||||
|
||||
|
@ -70,18 +70,6 @@ class FisherFactorTestingDummy(ff.FisherFactor):
|
||||
def get_cov(self):
|
||||
return NotImplementedError
|
||||
|
||||
def left_multiply(self, x, damping):
|
||||
return NotImplementedError
|
||||
|
||||
def right_multiply(self, x, damping):
|
||||
return NotImplementedError
|
||||
|
||||
def left_multiply_matpower(self, x, exp, damping):
|
||||
return NotImplementedError
|
||||
|
||||
def right_multiply_matpower(self, x, exp, damping):
|
||||
return NotImplementedError
|
||||
|
||||
def instantiate_inv_variables(self):
|
||||
return NotImplementedError
|
||||
|
||||
@ -91,14 +79,35 @@ class FisherFactorTestingDummy(ff.FisherFactor):
|
||||
def _get_data_device(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def register_matpower(self, exp, damping_func):
|
||||
raise NotImplementedError
|
||||
|
||||
class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor):
|
||||
"""Dummy class to test the non-abstract methods on ff.InverseProvidingFactor.
|
||||
def register_cholesky(self, damping_func):
|
||||
raise NotImplementedError
|
||||
|
||||
def register_cholesky_inverse(self, damping_func):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_matpower(self, exp, damping_func):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_cholesky(self, damping_func):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_cholesky_inverse(self, damping_func):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_cov_as_linear_operator(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DenseSquareMatrixFactorTestingDummy(ff.DenseSquareMatrixFactor):
|
||||
"""Dummy class to test the non-abstract methods on ff.DenseSquareMatrixFactor.
|
||||
"""
|
||||
|
||||
def __init__(self, shape):
|
||||
self._shape = shape
|
||||
super(InverseProvidingFactorTestingDummy, self).__init__()
|
||||
super(DenseSquareMatrixFactorTestingDummy, self).__init__()
|
||||
|
||||
@property
|
||||
def _var_scope(self):
|
||||
@ -230,13 +239,13 @@ class FisherFactorTest(test.TestCase):
|
||||
self.assertEqual(0, len(factor.make_inverse_update_ops()))
|
||||
|
||||
|
||||
class InverseProvidingFactorTest(test.TestCase):
|
||||
class DenseSquareMatrixFactorTest(test.TestCase):
|
||||
|
||||
def testRegisterDampedInverse(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
random_seed.set_random_seed(200)
|
||||
shape = [2, 2]
|
||||
factor = InverseProvidingFactorTestingDummy(shape)
|
||||
factor = DenseSquareMatrixFactorTestingDummy(shape)
|
||||
factor_var_scope = 'dummy/a_b_c'
|
||||
|
||||
damping_funcs = [make_damping_func(0.1),
|
||||
@ -248,22 +257,25 @@ class InverseProvidingFactorTest(test.TestCase):
|
||||
|
||||
factor.instantiate_inv_variables()
|
||||
|
||||
inv = factor.get_inverse(damping_funcs[0])
|
||||
self.assertEqual(inv, factor.get_inverse(damping_funcs[1]))
|
||||
self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]))
|
||||
self.assertEqual(factor.get_inverse(damping_funcs[2]),
|
||||
factor.get_inverse(damping_funcs[3]))
|
||||
inv = factor.get_inverse(damping_funcs[0]).to_dense()
|
||||
self.assertEqual(inv, factor.get_inverse(damping_funcs[1]).to_dense())
|
||||
self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]).to_dense())
|
||||
self.assertEqual(factor.get_inverse(damping_funcs[2]).to_dense(),
|
||||
factor.get_inverse(damping_funcs[3]).to_dense())
|
||||
factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
|
||||
factor_var_scope)
|
||||
self.assertEqual(set([inv, factor.get_inverse(damping_funcs[2])]),
|
||||
set(factor_vars))
|
||||
factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
|
||||
|
||||
self.assertEqual(set([inv,
|
||||
factor.get_inverse(damping_funcs[2]).to_dense()]),
|
||||
set(factor_tensors))
|
||||
self.assertEqual(shape, inv.get_shape())
|
||||
|
||||
def testRegisterMatpower(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
random_seed.set_random_seed(200)
|
||||
shape = [3, 3]
|
||||
factor = InverseProvidingFactorTestingDummy(shape)
|
||||
factor = DenseSquareMatrixFactorTestingDummy(shape)
|
||||
factor_var_scope = 'dummy/a_b_c'
|
||||
|
||||
# TODO(b/74201126): Change to using the same func for both once
|
||||
@ -278,10 +290,13 @@ class InverseProvidingFactorTest(test.TestCase):
|
||||
|
||||
factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
|
||||
factor_var_scope)
|
||||
matpower1 = factor.get_matpower(-0.5, damping_func_1)
|
||||
matpower2 = factor.get_matpower(2, damping_func_2)
|
||||
|
||||
self.assertEqual(set([matpower1, matpower2]), set(factor_vars))
|
||||
factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
|
||||
|
||||
matpower1 = factor.get_matpower(-0.5, damping_func_1).to_dense()
|
||||
matpower2 = factor.get_matpower(2, damping_func_2).to_dense()
|
||||
|
||||
self.assertEqual(set([matpower1, matpower2]), set(factor_tensors))
|
||||
|
||||
self.assertEqual(shape, matpower1.get_shape())
|
||||
self.assertEqual(shape, matpower2.get_shape())
|
||||
@ -297,7 +312,7 @@ class InverseProvidingFactorTest(test.TestCase):
|
||||
with tf_ops.Graph().as_default(), self.test_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
cov = np.array([[1., 2.], [3., 4.]])
|
||||
factor = InverseProvidingFactorTestingDummy(cov.shape)
|
||||
factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
|
||||
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
|
||||
|
||||
damping_funcs = []
|
||||
@ -316,7 +331,8 @@ class InverseProvidingFactorTest(test.TestCase):
|
||||
sess.run(ops)
|
||||
for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
|
||||
# The inverse op will assign the damped inverse of cov to the inv var.
|
||||
new_invs.append(sess.run(factor.get_inverse(damping_funcs[i])))
|
||||
new_invs.append(
|
||||
sess.run(factor.get_inverse(damping_funcs[i]).to_dense()))
|
||||
|
||||
# We want to see that the new invs are all different from each other.
|
||||
for i in range(len(new_invs)):
|
||||
@ -328,7 +344,7 @@ class InverseProvidingFactorTest(test.TestCase):
|
||||
with tf_ops.Graph().as_default(), self.test_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
cov = np.array([[6., 2.], [2., 4.]])
|
||||
factor = InverseProvidingFactorTestingDummy(cov.shape)
|
||||
factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
|
||||
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
|
||||
exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power
|
||||
damping = 0.5
|
||||
@ -341,7 +357,7 @@ class InverseProvidingFactorTest(test.TestCase):
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
sess.run(ops[0])
|
||||
matpower = sess.run(factor.get_matpower(exp, damping_func))
|
||||
matpower = sess.run(factor.get_matpower(exp, damping_func).to_dense())
|
||||
matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp)
|
||||
self.assertAllClose(matpower, matpower_np)
|
||||
|
||||
@ -349,7 +365,7 @@ class InverseProvidingFactorTest(test.TestCase):
|
||||
with tf_ops.Graph().as_default(), self.test_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
cov = np.array([[5., 2.], [2., 4.]]) # NOTE(mattjj): must be symmetric
|
||||
factor = InverseProvidingFactorTestingDummy(cov.shape)
|
||||
factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
|
||||
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
|
||||
|
||||
damping_func = make_damping_func(0)
|
||||
@ -361,12 +377,12 @@ class InverseProvidingFactorTest(test.TestCase):
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
# The inverse op will assign the damped inverse of cov to the inv var.
|
||||
old_inv = sess.run(factor.get_inverse(damping_func))
|
||||
old_inv = sess.run(factor.get_inverse(damping_func).to_dense())
|
||||
self.assertAllClose(
|
||||
sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv)
|
||||
|
||||
sess.run(ops)
|
||||
new_inv = sess.run(factor.get_inverse(damping_func))
|
||||
new_inv = sess.run(factor.get_inverse(damping_func).to_dense())
|
||||
self.assertAllClose(new_inv, np.linalg.inv(cov))
|
||||
|
||||
|
||||
@ -411,7 +427,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
|
||||
tensor = array_ops.ones((2, 3), name='a/b/c')
|
||||
factor = ff.NaiveDiagonalFactor((tensor,), 32)
|
||||
factor.instantiate_cov_variables()
|
||||
self.assertEqual([6, 1], factor.get_cov_var().get_shape().as_list())
|
||||
self.assertEqual([6, 1], factor.get_cov().get_shape().as_list())
|
||||
|
||||
def testNaiveDiagonalFactorInitFloat64(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
@ -420,7 +436,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
|
||||
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
|
||||
factor = ff.NaiveDiagonalFactor((tensor,), 32)
|
||||
factor.instantiate_cov_variables()
|
||||
cov = factor.get_cov_var()
|
||||
cov = factor.get_cov()
|
||||
self.assertEqual(cov.dtype, dtype)
|
||||
self.assertEqual([6, 1], cov.get_shape().as_list())
|
||||
|
||||
@ -444,7 +460,7 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase):
|
||||
vocab_size = 5
|
||||
factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
|
||||
factor.instantiate_cov_variables()
|
||||
cov = factor.get_cov_var()
|
||||
cov = factor.get_cov()
|
||||
self.assertEqual(cov.shape.as_list(), [vocab_size])
|
||||
|
||||
def testCovarianceUpdateOp(self):
|
||||
@ -502,7 +518,7 @@ class ConvDiagonalFactorTest(test.TestCase):
|
||||
self.kernel_height * self.kernel_width * self.in_channels,
|
||||
self.out_channels
|
||||
],
|
||||
factor.get_cov_var().shape.as_list())
|
||||
factor.get_cov().shape.as_list())
|
||||
|
||||
def testMakeCovarianceUpdateOp(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
@ -564,7 +580,7 @@ class ConvDiagonalFactorTest(test.TestCase):
|
||||
self.kernel_height * self.kernel_width * self.in_channels + 1,
|
||||
self.out_channels
|
||||
],
|
||||
factor.get_cov_var().shape.as_list())
|
||||
factor.get_cov().shape.as_list())
|
||||
|
||||
# Ensure update op doesn't crash.
|
||||
cov_update_op = factor.make_covariance_update_op(0.0)
|
||||
@ -654,13 +670,13 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
|
||||
# Ensure shape of covariance matches input size of filter.
|
||||
input_size = in_channels * (width**3)
|
||||
self.assertEqual([input_size, input_size],
|
||||
factor.get_cov_var().shape.as_list())
|
||||
factor.get_cov().shape.as_list())
|
||||
|
||||
# Ensure cov_update_op doesn't crash.
|
||||
with self.test_session() as sess:
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
sess.run(factor.make_covariance_update_op(0.0))
|
||||
cov = sess.run(factor.get_cov_var())
|
||||
cov = sess.run(factor.get_cov())
|
||||
|
||||
# Cov should be rank-8, as the filter will be applied at each corner of
|
||||
# the 4-D cube.
|
||||
@ -685,13 +701,13 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
|
||||
|
||||
# Ensure shape of covariance matches input size of filter.
|
||||
self.assertEqual([in_channels, in_channels],
|
||||
factor.get_cov_var().shape.as_list())
|
||||
factor.get_cov().shape.as_list())
|
||||
|
||||
# Ensure cov_update_op doesn't crash.
|
||||
with self.test_session() as sess:
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
sess.run(factor.make_covariance_update_op(0.0))
|
||||
cov = sess.run(factor.get_cov_var())
|
||||
cov = sess.run(factor.get_cov())
|
||||
|
||||
# Cov should be rank-9, as the filter will be applied at each location.
|
||||
self.assertMatrixRank(9, cov)
|
||||
@ -716,7 +732,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
|
||||
with self.test_session() as sess:
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
sess.run(factor.make_covariance_update_op(0.0))
|
||||
cov = sess.run(factor.get_cov_var())
|
||||
cov = sess.run(factor.get_cov())
|
||||
|
||||
# Cov should be the sum of 3 * 2 = 6 outer products.
|
||||
self.assertMatrixRank(6, cov)
|
||||
@ -742,7 +758,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
|
||||
with self.test_session() as sess:
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
sess.run(factor.make_covariance_update_op(0.0))
|
||||
cov = sess.run(factor.get_cov_var())
|
||||
cov = sess.run(factor.get_cov())
|
||||
|
||||
# Cov should be rank = in_channels, as only the center of the filter
|
||||
# receives non-zero input for each input channel.
|
||||
|
@ -35,6 +35,7 @@ py_library(
|
||||
srcs = ["fisher_factors.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":linear_operator",
|
||||
":utils",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
@ -63,6 +64,19 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "linear_operator",
|
||||
srcs = ["linear_operator.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":utils",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "loss_functions",
|
||||
srcs = ["loss_functions.py"],
|
||||
|
@ -57,8 +57,8 @@ def make_fisher_estimator(placement_strategy=None, **kwargs):
|
||||
if placement_strategy in [None, "round_robin"]:
|
||||
return FisherEstimatorRoundRobin(**kwargs)
|
||||
else:
|
||||
raise ValueError("Unimplemented vars and ops placement strategy : %s",
|
||||
placement_strategy)
|
||||
raise ValueError("Unimplemented vars and ops "
|
||||
"placement strategy : {}".format(placement_strategy))
|
||||
# pylint: enable=abstract-class-instantiated
|
||||
|
||||
|
||||
@ -81,7 +81,9 @@ class FisherEstimator(object):
|
||||
exps=(-1,),
|
||||
estimation_mode="gradients",
|
||||
colocate_gradients_with_ops=True,
|
||||
name="FisherEstimator"):
|
||||
name="FisherEstimator",
|
||||
compute_cholesky=False,
|
||||
compute_cholesky_inverse=False):
|
||||
"""Create a FisherEstimator object.
|
||||
|
||||
Args:
|
||||
@ -124,6 +126,12 @@ class FisherEstimator(object):
|
||||
name: A string. A name given to this estimator, which is added to the
|
||||
variable scope when constructing variables and ops.
|
||||
(Default: "FisherEstimator")
|
||||
compute_cholesky: Bool. Whether or not the FisherEstimator will be
|
||||
able to multiply vectors by the Cholesky factor.
|
||||
(Default: False)
|
||||
compute_cholesky_inverse: Bool. Whether or not the FisherEstimator
|
||||
will be able to multiply vectors by the Cholesky factor inverse.
|
||||
(Default: False)
|
||||
Raises:
|
||||
ValueError: If no losses have been registered with layer_collection.
|
||||
"""
|
||||
@ -142,6 +150,8 @@ class FisherEstimator(object):
|
||||
|
||||
self._made_vars = False
|
||||
self._exps = exps
|
||||
self._compute_cholesky = compute_cholesky
|
||||
self._compute_cholesky_inverse = compute_cholesky_inverse
|
||||
|
||||
self._name = name
|
||||
|
||||
@ -300,9 +310,54 @@ class FisherEstimator(object):
|
||||
A list of (transformed vector, var) pairs in the same order as
|
||||
vecs_and_vars.
|
||||
"""
|
||||
assert exp in self._exps
|
||||
|
||||
fcn = lambda fb, vec: fb.multiply_matpower(vec, exp)
|
||||
return self._apply_transformation(vecs_and_vars, fcn)
|
||||
|
||||
def multiply_cholesky(self, vecs_and_vars, transpose=False):
|
||||
"""Multiplies the vecs by the corresponding Cholesky factors.
|
||||
|
||||
Args:
|
||||
vecs_and_vars: List of (vector, variable) pairs.
|
||||
transpose: Bool. If true the Cholesky factors are transposed before
|
||||
multiplying the vecs. (Default: False)
|
||||
|
||||
Returns:
|
||||
A list of (transformed vector, var) pairs in the same order as
|
||||
vecs_and_vars.
|
||||
"""
|
||||
assert self._compute_cholesky
|
||||
|
||||
fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose)
|
||||
return self._apply_transformation(vecs_and_vars, fcn)
|
||||
|
||||
def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False):
|
||||
"""Mults the vecs by the inverses of the corresponding Cholesky factors.
|
||||
|
||||
Note: if you are using Cholesky inverse multiplication to sample from
|
||||
a matrix-variate Gaussian you will want to multiply by the transpose.
|
||||
Let L be the Cholesky factor of F and observe that
|
||||
|
||||
L^-T * L^-1 = (L * L^T)^-1 = F^-1 .
|
||||
|
||||
Thus we want to multiply by L^-T in order to sample from Gaussian with
|
||||
covariance F^-1.
|
||||
|
||||
Args:
|
||||
vecs_and_vars: List of (vector, variable) pairs.
|
||||
transpose: Bool. If true the Cholesky factor inverses are transposed
|
||||
before multiplying the vecs. (Default: False)
|
||||
|
||||
Returns:
|
||||
A list of (transformed vector, var) pairs in the same order as
|
||||
vecs_and_vars.
|
||||
"""
|
||||
assert self._compute_cholesky_inverse
|
||||
|
||||
fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose)
|
||||
return self._apply_transformation(vecs_and_vars, fcn)
|
||||
|
||||
def _instantiate_factors(self):
|
||||
"""Instantiates FisherFactors' variables.
|
||||
|
||||
@ -333,9 +388,13 @@ class FisherEstimator(object):
|
||||
return self._made_vars
|
||||
|
||||
def _register_matrix_functions(self):
|
||||
for exp in self._exps:
|
||||
for block in self.blocks:
|
||||
for block in self.blocks:
|
||||
for exp in self._exps:
|
||||
block.register_matpower(exp)
|
||||
if self._compute_cholesky:
|
||||
block.register_cholesky()
|
||||
if self._compute_cholesky_inverse:
|
||||
block.register_cholesky_inverse()
|
||||
|
||||
def _finalize_layer_collection(self):
|
||||
self._layers.create_subgraph()
|
||||
|
@ -25,6 +25,7 @@ from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
_allowed_symbols = [
|
||||
'FisherEstimator',
|
||||
'make_fisher_estimator',
|
||||
]
|
||||
|
||||
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
|
||||
|
@ -83,34 +83,22 @@ def normalize_damping(damping, num_replications):
|
||||
|
||||
|
||||
def compute_pi_tracenorm(left_cov, right_cov):
|
||||
"""Computes the scalar constant pi for Tikhonov regularization/damping.
|
||||
r"""Computes the scalar constant pi for Tikhonov regularization/damping.
|
||||
|
||||
$$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$
|
||||
See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details.
|
||||
|
||||
Args:
|
||||
left_cov: The left Kronecker factor "covariance".
|
||||
right_cov: The right Kronecker factor "covariance".
|
||||
left_cov: A LinearOperator object. The left Kronecker factor "covariance".
|
||||
right_cov: A LinearOperator object. The right Kronecker factor "covariance".
|
||||
|
||||
Returns:
|
||||
The computed scalar constant pi for these Kronecker Factors (as a Tensor).
|
||||
"""
|
||||
|
||||
def _trace(cov):
|
||||
if len(cov.shape) == 1:
|
||||
# Diagonal matrix.
|
||||
return math_ops.reduce_sum(cov)
|
||||
elif len(cov.shape) == 2:
|
||||
# Full matrix.
|
||||
return math_ops.trace(cov)
|
||||
else:
|
||||
raise ValueError(
|
||||
"What's the trace of a Tensor of rank %d?" % len(cov.shape))
|
||||
|
||||
# Instead of dividing by the dim of the norm, we multiply by the dim of the
|
||||
# other norm. This works out the same in the ratio.
|
||||
left_norm = _trace(left_cov) * right_cov.shape.as_list()[0]
|
||||
right_norm = _trace(right_cov) * left_cov.shape.as_list()[0]
|
||||
left_norm = left_cov.trace() * int(right_cov.domain_dimension)
|
||||
right_norm = right_cov.trace() * int(left_cov.domain_dimension)
|
||||
return math_ops.sqrt(left_norm / right_norm)
|
||||
|
||||
|
||||
@ -188,6 +176,16 @@ class FisherBlock(object):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def register_cholesky(self):
|
||||
"""Registers a Cholesky factor to be computed by the block."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def register_cholesky_inverse(self):
|
||||
"""Registers an inverse Cholesky factor to be computed by the block."""
|
||||
pass
|
||||
|
||||
def register_inverse(self):
|
||||
"""Registers a matrix inverse to be computed by the block."""
|
||||
self.register_matpower(-1)
|
||||
@ -228,6 +226,33 @@ class FisherBlock(object):
|
||||
"""
|
||||
return self.multiply_matpower(vector, 1)
|
||||
|
||||
@abc.abstractmethod
|
||||
def multiply_cholesky(self, vector, transpose=False):
|
||||
"""Multiplies the vector by the (damped) Cholesky-factor of the block.
|
||||
|
||||
Args:
|
||||
vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
|
||||
transpose: Bool. If true the Cholesky factor is transposed before
|
||||
multiplying the vector. (Default: False)
|
||||
|
||||
Returns:
|
||||
The vector left-multiplied by the (damped) Cholesky-factor of the block.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def multiply_cholesky_inverse(self, vector, transpose=False):
|
||||
"""Multiplies vector by the (damped) inverse Cholesky-factor of the block.
|
||||
|
||||
Args:
|
||||
vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
|
||||
transpose: Bool. If true the Cholesky factor inverse is transposed
|
||||
before multiplying the vector. (Default: False)
|
||||
Returns:
|
||||
Vector left-multiplied by (damped) inverse Cholesky-factor of the block.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def tensors_to_compute_grads(self):
|
||||
"""Returns the Tensor(s) with respect to which this FisherBlock needs grads.
|
||||
@ -275,15 +300,32 @@ class FullFB(FisherBlock):
|
||||
def register_matpower(self, exp):
|
||||
self._factor.register_matpower(exp, self._damping_func)
|
||||
|
||||
def multiply_matpower(self, vector, exp):
|
||||
def register_cholesky(self):
|
||||
self._factor.register_cholesky(self._damping_func)
|
||||
|
||||
def register_cholesky_inverse(self):
|
||||
self._factor.register_cholesky_inverse(self._damping_func)
|
||||
|
||||
def _multiply_matrix(self, matrix, vector, transpose=False):
|
||||
vector_flat = utils.tensors_to_column(vector)
|
||||
out_flat = self._factor.left_multiply_matpower(
|
||||
vector_flat, exp, self._damping_func)
|
||||
out_flat = matrix.matmul(vector_flat, adjoint=transpose)
|
||||
return utils.column_to_tensors(vector, out_flat)
|
||||
|
||||
def multiply_matpower(self, vector, exp):
|
||||
matrix = self._factor.get_matpower(exp, self._damping_func)
|
||||
return self._multiply_matrix(matrix, vector)
|
||||
|
||||
def multiply_cholesky(self, vector, transpose=False):
|
||||
matrix = self._factor.get_cholesky(self._damping_func)
|
||||
return self._multiply_matrix(matrix, vector, transpose=transpose)
|
||||
|
||||
def multiply_cholesky_inverse(self, vector, transpose=False):
|
||||
matrix = self._factor.get_cholesky_inverse(self._damping_func)
|
||||
return self._multiply_matrix(matrix, vector, transpose=transpose)
|
||||
|
||||
def full_fisher_block(self):
|
||||
"""Explicitly constructs the full Fisher block."""
|
||||
return self._factor.get_cov()
|
||||
return self._factor.get_cov_as_linear_operator().to_dense()
|
||||
|
||||
def tensors_to_compute_grads(self):
|
||||
return self._params
|
||||
@ -305,7 +347,47 @@ class FullFB(FisherBlock):
|
||||
return math_ops.reduce_sum(self._batch_sizes)
|
||||
|
||||
|
||||
class NaiveDiagonalFB(FisherBlock):
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class DiagonalFB(FisherBlock):
|
||||
"""A base class for FisherBlocks that use diagonal approximations."""
|
||||
|
||||
def register_matpower(self, exp):
|
||||
# Not needed for this. Matrix powers are computed on demand in the
|
||||
# diagonal case
|
||||
pass
|
||||
|
||||
def register_cholesky(self):
|
||||
# Not needed for this. Cholesky's are computed on demand in the
|
||||
# diagonal case
|
||||
pass
|
||||
|
||||
def register_cholesky_inverse(self):
|
||||
# Not needed for this. Cholesky inverses's are computed on demand in the
|
||||
# diagonal case
|
||||
pass
|
||||
|
||||
def _multiply_matrix(self, matrix, vector):
|
||||
vector_flat = utils.tensors_to_column(vector)
|
||||
out_flat = matrix.matmul(vector_flat)
|
||||
return utils.column_to_tensors(vector, out_flat)
|
||||
|
||||
def multiply_matpower(self, vector, exp):
|
||||
matrix = self._factor.get_matpower(exp, self._damping_func)
|
||||
return self._multiply_matrix(matrix, vector)
|
||||
|
||||
def multiply_cholesky(self, vector, transpose=False):
|
||||
matrix = self._factor.get_cholesky(self._damping_func)
|
||||
return self._multiply_matrix(matrix, vector)
|
||||
|
||||
def multiply_cholesky_inverse(self, vector, transpose=False):
|
||||
matrix = self._factor.get_cholesky_inverse(self._damping_func)
|
||||
return self._multiply_matrix(matrix, vector)
|
||||
|
||||
def full_fisher_block(self):
|
||||
return self._factor.get_cov_as_linear_operator().to_dense()
|
||||
|
||||
|
||||
class NaiveDiagonalFB(DiagonalFB):
|
||||
"""FisherBlock using a diagonal matrix approximation.
|
||||
|
||||
This type of approximation is generically applicable but quite primitive.
|
||||
@ -333,20 +415,6 @@ class NaiveDiagonalFB(FisherBlock):
|
||||
self._factor = self._layer_collection.make_or_get_factor(
|
||||
fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size))
|
||||
|
||||
def register_matpower(self, exp):
|
||||
# Not needed for this. Matrix powers are computed on demand in the
|
||||
# diagonal case
|
||||
pass
|
||||
|
||||
def multiply_matpower(self, vector, exp):
|
||||
vector_flat = utils.tensors_to_column(vector)
|
||||
out_flat = self._factor.left_multiply_matpower(
|
||||
vector_flat, exp, self._damping_func)
|
||||
return utils.column_to_tensors(vector, out_flat)
|
||||
|
||||
def full_fisher_block(self):
|
||||
return self._factor.get_cov()
|
||||
|
||||
def tensors_to_compute_grads(self):
|
||||
return self._params
|
||||
|
||||
@ -452,7 +520,7 @@ class InputOutputMultiTower(object):
|
||||
return self.__outputs
|
||||
|
||||
|
||||
class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock):
|
||||
class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB):
|
||||
"""FisherBlock for fully-connected (dense) layers using a diagonal approx.
|
||||
|
||||
Estimates the Fisher Information matrix's diagonal entries for a fully
|
||||
@ -497,32 +565,8 @@ class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock):
|
||||
|
||||
self._damping_func = _package_func(lambda: damping, (damping,))
|
||||
|
||||
def register_matpower(self, exp):
|
||||
# Not needed for this. Matrix powers are computed on demand in the
|
||||
# diagonal case
|
||||
pass
|
||||
|
||||
def multiply_matpower(self, vector, exp):
|
||||
"""Multiplies the vector by the (damped) matrix-power of the block.
|
||||
|
||||
Args:
|
||||
vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape
|
||||
[input_size, output_size] corresponding to layer's weights. If not, a
|
||||
2-tuple of the former and a Tensor of shape [output_size] corresponding
|
||||
to the layer's bias.
|
||||
exp: A scalar representing the power to raise the block before multiplying
|
||||
it by the vector.
|
||||
|
||||
Returns:
|
||||
The vector left-multiplied by the (damped) matrix-power of the block.
|
||||
"""
|
||||
reshaped_vec = utils.layer_params_to_mat2d(vector)
|
||||
reshaped_out = self._factor.left_multiply_matpower(
|
||||
reshaped_vec, exp, self._damping_func)
|
||||
return utils.mat2d_to_layer_params(vector, reshaped_out)
|
||||
|
||||
|
||||
class ConvDiagonalFB(InputOutputMultiTower, FisherBlock):
|
||||
class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB):
|
||||
"""FisherBlock for 2-D convolutional layers using a diagonal approx.
|
||||
|
||||
Estimates the Fisher Information matrix's diagonal entries for a convolutional
|
||||
@ -621,17 +665,6 @@ class ConvDiagonalFB(InputOutputMultiTower, FisherBlock):
|
||||
self._num_locations)
|
||||
self._damping_func = _package_func(damping_func, damping_id)
|
||||
|
||||
def register_matpower(self, exp):
|
||||
# Not needed for this. Matrix powers are computed on demand in the
|
||||
# diagonal case
|
||||
pass
|
||||
|
||||
def multiply_matpower(self, vector, exp):
|
||||
reshaped_vect = utils.layer_params_to_mat2d(vector)
|
||||
reshaped_out = self._factor.left_multiply_matpower(
|
||||
reshaped_vect, exp, self._damping_func)
|
||||
return utils.mat2d_to_layer_params(vector, reshaped_out)
|
||||
|
||||
|
||||
class KroneckerProductFB(FisherBlock):
|
||||
"""A base class for blocks with separate input and output Kronecker factors.
|
||||
@ -651,9 +684,10 @@ class KroneckerProductFB(FisherBlock):
|
||||
else:
|
||||
maybe_normalized_damping = damping
|
||||
|
||||
return compute_pi_adjusted_damping(self._input_factor.get_cov(),
|
||||
self._output_factor.get_cov(),
|
||||
maybe_normalized_damping**0.5)
|
||||
return compute_pi_adjusted_damping(
|
||||
self._input_factor.get_cov_as_linear_operator(),
|
||||
self._output_factor.get_cov_as_linear_operator(),
|
||||
maybe_normalized_damping**0.5)
|
||||
|
||||
if normalization is not None:
|
||||
damping_id = ("compute_pi_adjusted_damping",
|
||||
@ -675,6 +709,14 @@ class KroneckerProductFB(FisherBlock):
|
||||
self._input_factor.register_matpower(exp, self._input_damping_func)
|
||||
self._output_factor.register_matpower(exp, self._output_damping_func)
|
||||
|
||||
def register_cholesky(self):
|
||||
self._input_factor.register_cholesky(self._input_damping_func)
|
||||
self._output_factor.register_cholesky(self._output_damping_func)
|
||||
|
||||
def register_cholesky_inverse(self):
|
||||
self._input_factor.register_cholesky_inverse(self._input_damping_func)
|
||||
self._output_factor.register_cholesky_inverse(self._output_damping_func)
|
||||
|
||||
@property
|
||||
def _renorm_coeff(self):
|
||||
"""Kronecker factor multiplier coefficient.
|
||||
@ -687,17 +729,47 @@ class KroneckerProductFB(FisherBlock):
|
||||
"""
|
||||
return 1.0
|
||||
|
||||
def multiply_matpower(self, vector, exp):
|
||||
def _multiply_factored_matrix(self, left_factor, right_factor, vector,
|
||||
extra_scale=1.0, transpose_left=False,
|
||||
transpose_right=False):
|
||||
reshaped_vector = utils.layer_params_to_mat2d(vector)
|
||||
reshaped_out = self._output_factor.right_multiply_matpower(
|
||||
reshaped_vector, exp, self._output_damping_func)
|
||||
reshaped_out = self._input_factor.left_multiply_matpower(
|
||||
reshaped_out, exp, self._input_damping_func)
|
||||
if self._renorm_coeff != 1.0:
|
||||
renorm_coeff = math_ops.cast(self._renorm_coeff, dtype=reshaped_out.dtype)
|
||||
reshaped_out *= math_ops.cast(renorm_coeff**exp, dtype=reshaped_out.dtype)
|
||||
reshaped_out = right_factor.matmul_right(reshaped_vector,
|
||||
adjoint=transpose_right)
|
||||
reshaped_out = left_factor.matmul(reshaped_out,
|
||||
adjoint=transpose_left)
|
||||
if extra_scale != 1.0:
|
||||
reshaped_out *= math_ops.cast(extra_scale, dtype=reshaped_out.dtype)
|
||||
return utils.mat2d_to_layer_params(vector, reshaped_out)
|
||||
|
||||
def multiply_matpower(self, vector, exp):
|
||||
left_factor = self._input_factor.get_matpower(
|
||||
exp, self._input_damping_func)
|
||||
right_factor = self._output_factor.get_matpower(
|
||||
exp, self._output_damping_func)
|
||||
extra_scale = float(self._renorm_coeff)**exp
|
||||
return self._multiply_factored_matrix(left_factor, right_factor, vector,
|
||||
extra_scale=extra_scale)
|
||||
|
||||
def multiply_cholesky(self, vector, transpose=False):
|
||||
left_factor = self._input_factor.get_cholesky(self._input_damping_func)
|
||||
right_factor = self._output_factor.get_cholesky(self._output_damping_func)
|
||||
extra_scale = float(self._renorm_coeff)**0.5
|
||||
return self._multiply_factored_matrix(left_factor, right_factor, vector,
|
||||
extra_scale=extra_scale,
|
||||
transpose_left=transpose,
|
||||
transpose_right=not transpose)
|
||||
|
||||
def multiply_cholesky_inverse(self, vector, transpose=False):
|
||||
left_factor = self._input_factor.get_cholesky_inverse(
|
||||
self._input_damping_func)
|
||||
right_factor = self._output_factor.get_cholesky_inverse(
|
||||
self._output_damping_func)
|
||||
extra_scale = float(self._renorm_coeff)**-0.5
|
||||
return self._multiply_factored_matrix(left_factor, right_factor, vector,
|
||||
extra_scale=extra_scale,
|
||||
transpose_left=transpose,
|
||||
transpose_right=not transpose)
|
||||
|
||||
def full_fisher_block(self):
|
||||
"""Explicitly constructs the full Fisher block.
|
||||
|
||||
@ -706,8 +778,8 @@ class KroneckerProductFB(FisherBlock):
|
||||
Returns:
|
||||
The full Fisher block.
|
||||
"""
|
||||
left_factor = self._input_factor.get_cov()
|
||||
right_factor = self._output_factor.get_cov()
|
||||
left_factor = self._input_factor.get_cov_as_linear_operator().to_dense()
|
||||
right_factor = self._output_factor.get_cov_as_linear_operator().to_dense()
|
||||
return self._renorm_coeff * utils.kronecker_product(left_factor,
|
||||
right_factor)
|
||||
|
||||
@ -796,7 +868,7 @@ class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB):
|
||||
|
||||
|
||||
class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
|
||||
"""FisherBlock for convolutional layers using the basic KFC approx.
|
||||
r"""FisherBlock for convolutional layers using the basic KFC approx.
|
||||
|
||||
Estimates the Fisher Information matrix's blog for a convolutional
|
||||
layer.
|
||||
@ -945,10 +1017,10 @@ class DepthwiseConvDiagonalFB(ConvDiagonalFB):
|
||||
self._filter_shape = (filter_height, filter_width, in_channels,
|
||||
in_channels * channel_multiplier)
|
||||
|
||||
def multiply_matpower(self, vector, exp):
|
||||
def _multiply_matrix(self, matrix, vector):
|
||||
conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
|
||||
conv2d_result = super(DepthwiseConvDiagonalFB, self).multiply_matpower(
|
||||
conv2d_vector, exp)
|
||||
conv2d_result = super(
|
||||
DepthwiseConvDiagonalFB, self)._multiply_matrix(matrix, conv2d_vector)
|
||||
return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
|
||||
|
||||
|
||||
@ -1016,10 +1088,14 @@ class DepthwiseConvKFCBasicFB(ConvKFCBasicFB):
|
||||
self._filter_shape = (filter_height, filter_width, in_channels,
|
||||
in_channels * channel_multiplier)
|
||||
|
||||
def multiply_matpower(self, vector, exp):
|
||||
def _multiply_factored_matrix(self, left_factor, right_factor, vector,
|
||||
extra_scale=1.0, transpose_left=False,
|
||||
transpose_right=False):
|
||||
conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
|
||||
conv2d_result = super(DepthwiseConvKFCBasicFB, self).multiply_matpower(
|
||||
conv2d_vector, exp)
|
||||
conv2d_result = super(
|
||||
DepthwiseConvKFCBasicFB, self)._multiply_factored_matrix(
|
||||
left_factor, right_factor, conv2d_vector, extra_scale=extra_scale,
|
||||
transpose_left=transpose_left, transpose_right=transpose_right)
|
||||
return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
|
||||
|
||||
|
||||
@ -1664,3 +1740,12 @@ class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
|
||||
return utils.mat2d_to_layer_params(vector, Z)
|
||||
|
||||
# pylint: enable=invalid-name
|
||||
|
||||
def multiply_cholesky(self, vector):
|
||||
raise NotImplementedError("FullyConnectedSeriesFB does not support "
|
||||
"Cholesky computations.")
|
||||
|
||||
def multiply_cholesky_inverse(self, vector):
|
||||
raise NotImplementedError("FullyConnectedSeriesFB does not support "
|
||||
"Cholesky computations.")
|
||||
|
||||
|
@ -24,6 +24,7 @@ import contextlib
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.contrib.kfac.python.ops import linear_operator as lo
|
||||
from tensorflow.contrib.kfac.python.ops import utils
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops as tf_ops
|
||||
@ -399,7 +400,7 @@ class FisherFactor(object):
|
||||
the cov update.
|
||||
|
||||
Returns:
|
||||
Tensor of same shape as self.get_cov_var().
|
||||
Tensor of same shape as self.get_cov().
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -448,78 +449,43 @@ class FisherFactor(object):
|
||||
"""Create and return update ops corresponding to registered computations."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_cov(self):
|
||||
"""Get full covariance matrix.
|
||||
|
||||
Returns:
|
||||
Tensor of shape [n, n]. Represents all parameter-parameter correlations
|
||||
captured by this FisherFactor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_cov_var(self):
|
||||
"""Get variable backing this FisherFactor.
|
||||
|
||||
May or may not be the same as self.get_cov()
|
||||
|
||||
Returns:
|
||||
Variable of shape self._cov_shape.
|
||||
"""
|
||||
return self._cov
|
||||
|
||||
@abc.abstractmethod
|
||||
def left_multiply_matpower(self, x, exp, damping_func):
|
||||
"""Left multiplies 'x' by matrix power of this factor (w/ damping applied).
|
||||
|
||||
This calculation is essentially:
|
||||
(C + damping * I)**exp * x
|
||||
where * is matrix-multiplication, ** is matrix power, I is the identity
|
||||
matrix, and C is the matrix represented by this factor.
|
||||
|
||||
x can represent either a matrix or a vector. For some factors, 'x' might
|
||||
represent a vector but actually be stored as a 2D matrix for convenience.
|
||||
|
||||
Args:
|
||||
x: Tensor. Represents a single vector. Shape depends on implementation.
|
||||
exp: float. The matrix exponent to use.
|
||||
damping_func: A function that computes a 0-D Tensor or a float which will
|
||||
be the damping value used. i.e. damping = damping_func().
|
||||
|
||||
Returns:
|
||||
Tensor of same shape as 'x' representing the result of the multiplication.
|
||||
"""
|
||||
def get_cov_as_linear_operator(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def right_multiply_matpower(self, x, exp, damping_func):
|
||||
"""Right multiplies 'x' by matrix power of this factor (w/ damping applied).
|
||||
def register_matpower(self, exp, damping_func):
|
||||
pass
|
||||
|
||||
This calculation is essentially:
|
||||
x * (C + damping * I)**exp
|
||||
where * is matrix-multiplication, ** is matrix power, I is the identity
|
||||
matrix, and C is the matrix represented by this factor.
|
||||
@abc.abstractmethod
|
||||
def register_cholesky(self, damping_func):
|
||||
pass
|
||||
|
||||
Unlike left_multiply_matpower, x will always be a matrix.
|
||||
@abc.abstractmethod
|
||||
def register_cholesky_inverse(self, damping_func):
|
||||
pass
|
||||
|
||||
Args:
|
||||
x: Tensor. Represents a single vector. Shape depends on implementation.
|
||||
exp: float. The matrix exponent to use.
|
||||
damping_func: A function that computes a 0-D Tensor or a float which will
|
||||
be the damping value used. i.e. damping = damping_func().
|
||||
@abc.abstractmethod
|
||||
def get_matpower(self, exp, damping_func):
|
||||
pass
|
||||
|
||||
Returns:
|
||||
Tensor of same shape as 'x' representing the result of the multiplication.
|
||||
"""
|
||||
@abc.abstractmethod
|
||||
def get_cholesky(self, damping_func):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_cholesky_inverse(self, damping_func):
|
||||
pass
|
||||
|
||||
|
||||
class InverseProvidingFactor(FisherFactor):
|
||||
"""Base class for FisherFactors that maintain inverses explicitly.
|
||||
class DenseSquareMatrixFactor(FisherFactor):
|
||||
"""Base class for FisherFactors that are stored as dense square matrices.
|
||||
|
||||
This class explicitly calculates and stores inverses of covariance matrices
|
||||
provided by the underlying FisherFactor implementation. It is assumed that
|
||||
vectors can be represented as 2-D matrices.
|
||||
This class explicitly calculates and stores inverses of their `cov` matrices,
|
||||
which must be square dense matrices.
|
||||
|
||||
Subclasses must implement the _compute_new_cov method, and the _var_scope and
|
||||
_cov_shape properties.
|
||||
@ -538,7 +504,19 @@ class InverseProvidingFactor(FisherFactor):
|
||||
self._eigendecomp = None
|
||||
self._damping_funcs_by_id = {} # {hashable: lambda}
|
||||
|
||||
super(InverseProvidingFactor, self).__init__()
|
||||
self._cholesky_registrations = set() # { hashable }
|
||||
self._cholesky_inverse_registrations = set() # { hashable }
|
||||
|
||||
self._cholesky_by_damping = {} # { hashable: variable }
|
||||
self._cholesky_inverse_by_damping = {} # { hashable: variable }
|
||||
|
||||
super(DenseSquareMatrixFactor, self).__init__()
|
||||
|
||||
def get_cov_as_linear_operator(self):
|
||||
assert self.get_cov().shape.ndims == 2
|
||||
return lo.LinearOperatorFullMatrix(self.get_cov(),
|
||||
is_self_adjoint=True,
|
||||
is_square=True)
|
||||
|
||||
def _register_damping(self, damping_func):
|
||||
damping_id = graph_func_to_id(damping_func)
|
||||
@ -563,8 +541,6 @@ class InverseProvidingFactor(FisherFactor):
|
||||
be the damping value used. i.e. damping = damping_func().
|
||||
"""
|
||||
if exp == 1.0:
|
||||
# We don't register these. The user shouldn't even be calling this
|
||||
# function with exp = 1.0.
|
||||
return
|
||||
|
||||
damping_id = self._register_damping(damping_func)
|
||||
@ -572,6 +548,38 @@ class InverseProvidingFactor(FisherFactor):
|
||||
if (exp, damping_id) not in self._matpower_registrations:
|
||||
self._matpower_registrations.add((exp, damping_id))
|
||||
|
||||
def register_cholesky(self, damping_func):
|
||||
"""Registers a Cholesky factor to be maintained and served on demand.
|
||||
|
||||
This creates a variable and signals make_inverse_update_ops to make the
|
||||
corresponding update op. The variable can be read via the method
|
||||
get_cholesky.
|
||||
|
||||
Args:
|
||||
damping_func: A function that computes a 0-D Tensor or a float which will
|
||||
be the damping value used. i.e. damping = damping_func().
|
||||
"""
|
||||
damping_id = self._register_damping(damping_func)
|
||||
|
||||
if damping_id not in self._cholesky_registrations:
|
||||
self._cholesky_registrations.add(damping_id)
|
||||
|
||||
def register_cholesky_inverse(self, damping_func):
|
||||
"""Registers an inverse Cholesky factor to be maintained/served on demand.
|
||||
|
||||
This creates a variable and signals make_inverse_update_ops to make the
|
||||
corresponding update op. The variable can be read via the method
|
||||
get_cholesky_inverse.
|
||||
|
||||
Args:
|
||||
damping_func: A function that computes a 0-D Tensor or a float which will
|
||||
be the damping value used. i.e. damping = damping_func().
|
||||
"""
|
||||
damping_id = self._register_damping(damping_func)
|
||||
|
||||
if damping_id not in self._cholesky_inverse_registrations:
|
||||
self._cholesky_inverse_registrations.add(damping_id)
|
||||
|
||||
def instantiate_inv_variables(self):
|
||||
"""Makes the internal "inverse" variable(s)."""
|
||||
|
||||
@ -589,6 +597,32 @@ class InverseProvidingFactor(FisherFactor):
|
||||
assert (exp, damping_id) not in self._matpower_by_exp_and_damping
|
||||
self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower
|
||||
|
||||
for damping_id in self._cholesky_registrations:
|
||||
damping_func = self._damping_funcs_by_id[damping_id]
|
||||
damping_string = graph_func_to_string(damping_func)
|
||||
with variable_scope.variable_scope(self._var_scope):
|
||||
chol = variable_scope.get_variable(
|
||||
"cholesky_damp{}".format(damping_string),
|
||||
initializer=inverse_initializer,
|
||||
shape=self._cov_shape,
|
||||
trainable=False,
|
||||
dtype=self._dtype)
|
||||
assert damping_id not in self._cholesky_by_damping
|
||||
self._cholesky_by_damping[damping_id] = chol
|
||||
|
||||
for damping_id in self._cholesky_inverse_registrations:
|
||||
damping_func = self._damping_funcs_by_id[damping_id]
|
||||
damping_string = graph_func_to_string(damping_func)
|
||||
with variable_scope.variable_scope(self._var_scope):
|
||||
cholinv = variable_scope.get_variable(
|
||||
"cholesky_inverse_damp{}".format(damping_string),
|
||||
initializer=inverse_initializer,
|
||||
shape=self._cov_shape,
|
||||
trainable=False,
|
||||
dtype=self._dtype)
|
||||
assert damping_id not in self._cholesky_inverse_by_damping
|
||||
self._cholesky_inverse_by_damping[damping_id] = cholinv
|
||||
|
||||
def make_inverse_update_ops(self):
|
||||
"""Create and return update ops corresponding to registered computations."""
|
||||
ops = []
|
||||
@ -606,7 +640,8 @@ class InverseProvidingFactor(FisherFactor):
|
||||
|
||||
# We precompute these so we don't need to evaluate them multiple times (for
|
||||
# each matrix power that uses them)
|
||||
damping_value_by_id = {damping_id: self._damping_funcs_by_id[damping_id]()
|
||||
damping_value_by_id = {damping_id: math_ops.cast(
|
||||
self._damping_funcs_by_id[damping_id](), self._dtype)
|
||||
for damping_id in self._damping_funcs_by_id}
|
||||
|
||||
if use_eig:
|
||||
@ -627,29 +662,91 @@ class InverseProvidingFactor(FisherFactor):
|
||||
self._matpower_by_exp_and_damping.items()):
|
||||
assert exp == -1
|
||||
damping = damping_value_by_id[damping_id]
|
||||
ops.append(matpower.assign(utils.posdef_inv(self._cov, damping)))
|
||||
ops.append(matpower.assign(utils.posdef_inv(self.get_cov(), damping)))
|
||||
|
||||
# TODO(b/77902055): If inverses are being computed with Cholesky's
|
||||
# we can share the work. Instead this code currently just computes the
|
||||
# Cholesky a second time. It does at least share work between requests for
|
||||
# Cholesky's and Cholesky inverses with the same damping id.
|
||||
for damping_id, cholesky_inv in self._cholesky_inverse_by_damping.items():
|
||||
cholesky_ops = []
|
||||
|
||||
damping = damping_value_by_id[damping_id]
|
||||
cholesky_value = utils.cholesky(self.get_cov(), damping)
|
||||
|
||||
if damping_id in self._cholesky_by_damping:
|
||||
cholesky = self._cholesky_by_damping[damping_id]
|
||||
cholesky_ops.append(cholesky.assign(cholesky_value))
|
||||
|
||||
identity = linalg_ops.eye(cholesky_value.shape.as_list()[0],
|
||||
dtype=cholesky_value.dtype)
|
||||
cholesky_inv_value = linalg_ops.matrix_triangular_solve(cholesky_value,
|
||||
identity)
|
||||
cholesky_ops.append(cholesky_inv.assign(cholesky_inv_value))
|
||||
|
||||
ops.append(control_flow_ops.group(*cholesky_ops))
|
||||
|
||||
for damping_id, cholesky in self._cholesky_by_damping.items():
|
||||
if damping_id not in self._cholesky_inverse_by_damping:
|
||||
damping = damping_value_by_id[damping_id]
|
||||
cholesky_value = utils.cholesky(self.get_cov(), damping)
|
||||
ops.append(cholesky.assign(cholesky_value))
|
||||
|
||||
self._eigendecomp = False
|
||||
return ops
|
||||
|
||||
def get_inverse(self, damping_func):
|
||||
# Just for backwards compatibility of some old code and tests
|
||||
damping_id = graph_func_to_id(damping_func)
|
||||
return self._matpower_by_exp_and_damping[(-1, damping_id)]
|
||||
return self.get_matpower(-1, damping_func)
|
||||
|
||||
def get_matpower(self, exp, damping_func):
|
||||
# Note that this function returns a variable which gets updated by the
|
||||
# inverse ops. It may be stale / inconsistent with the latest value of
|
||||
# get_cov().
|
||||
if exp != 1:
|
||||
damping_id = graph_func_to_id(damping_func)
|
||||
matpower = self._matpower_by_exp_and_damping[(exp, damping_id)]
|
||||
else:
|
||||
matpower = self.get_cov()
|
||||
identity = linalg_ops.eye(matpower.shape.as_list()[0],
|
||||
dtype=matpower.dtype)
|
||||
matpower += math_ops.cast(damping_func(), dtype=matpower.dtype)*identity
|
||||
|
||||
assert matpower.shape.ndims == 2
|
||||
return lo.LinearOperatorFullMatrix(matpower,
|
||||
is_non_singular=True,
|
||||
is_self_adjoint=True,
|
||||
is_positive_definite=True,
|
||||
is_square=True)
|
||||
|
||||
def get_cholesky(self, damping_func):
|
||||
# Note that this function returns a variable which gets updated by the
|
||||
# inverse ops. It may be stale / inconsistent with the latest value of
|
||||
# get_cov().
|
||||
damping_id = graph_func_to_id(damping_func)
|
||||
return self._matpower_by_exp_and_damping[(exp, damping_id)]
|
||||
cholesky = self._cholesky_by_damping[damping_id]
|
||||
assert cholesky.shape.ndims == 2
|
||||
return lo.LinearOperatorFullMatrix(cholesky,
|
||||
is_non_singular=True,
|
||||
is_square=True)
|
||||
|
||||
def get_cholesky_inverse(self, damping_func):
|
||||
# Note that this function returns a variable which gets updated by the
|
||||
# inverse ops. It may be stale / inconsistent with the latest value of
|
||||
# get_cov().
|
||||
damping_id = graph_func_to_id(damping_func)
|
||||
cholesky_inv = self._cholesky_inverse_by_damping[damping_id]
|
||||
assert cholesky_inv.shape.ndims == 2
|
||||
return lo.LinearOperatorFullMatrix(cholesky_inv,
|
||||
is_non_singular=True,
|
||||
is_square=True)
|
||||
|
||||
def get_eigendecomp(self):
|
||||
"""Creates or retrieves eigendecomposition of self._cov."""
|
||||
# Unlike get_matpower this doesn't retrieve a stored variable, but instead
|
||||
# always computes a fresh version from the current value of get_cov().
|
||||
if not self._eigendecomp:
|
||||
eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov)
|
||||
eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self.get_cov())
|
||||
|
||||
# The matrix self._cov is positive semidefinite by construction, but the
|
||||
# numerical eigenvalues could be negative due to numerical errors, so here
|
||||
@ -660,45 +757,8 @@ class InverseProvidingFactor(FisherFactor):
|
||||
|
||||
return self._eigendecomp
|
||||
|
||||
def get_cov(self):
|
||||
# Variable contains full covariance matrix.
|
||||
return self.get_cov_var()
|
||||
|
||||
def left_multiply_matpower(self, x, exp, damping_func):
|
||||
if isinstance(x, tf_ops.IndexedSlices):
|
||||
raise ValueError("Left-multiply not yet supported for IndexedSlices.")
|
||||
|
||||
if x.shape.ndims != 2:
|
||||
raise ValueError(
|
||||
"InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
|
||||
% (x,))
|
||||
|
||||
if exp == 1:
|
||||
return math_ops.matmul(self.get_cov(), x) + damping_func() * x
|
||||
|
||||
return math_ops.matmul(self.get_matpower(exp, damping_func), x)
|
||||
|
||||
def right_multiply_matpower(self, x, exp, damping_func):
|
||||
if isinstance(x, tf_ops.IndexedSlices):
|
||||
if exp == 1:
|
||||
n = self.get_cov().shape[0]
|
||||
damped_cov = self.get_cov() + damping_func() * array_ops.eye(n)
|
||||
return utils.matmul_sparse_dense(x, damped_cov)
|
||||
|
||||
return utils.matmul_sparse_dense(x, self.get_matpower(exp, damping_func))
|
||||
|
||||
if x.shape.ndims != 2:
|
||||
raise ValueError(
|
||||
"InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
|
||||
% (x,))
|
||||
|
||||
if exp == 1:
|
||||
return math_ops.matmul(x, self.get_cov()) + damping_func() * x
|
||||
|
||||
return math_ops.matmul(x, self.get_matpower(exp, damping_func))
|
||||
|
||||
|
||||
class FullFactor(InverseProvidingFactor):
|
||||
class FullFactor(DenseSquareMatrixFactor):
|
||||
"""FisherFactor for a full matrix representation of the Fisher of a parameter.
|
||||
|
||||
Note that this uses the naive "square the sum estimator", and so is applicable
|
||||
@ -757,42 +817,52 @@ class DiagonalFactor(FisherFactor):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._damping_funcs_by_id = {} # { hashable: lambda }
|
||||
super(DiagonalFactor, self).__init__()
|
||||
|
||||
def get_cov_as_linear_operator(self):
|
||||
assert self._matrix_diagonal.shape.ndims == 1
|
||||
return lo.LinearOperatorDiag(self._matrix_diagonal,
|
||||
is_self_adjoint=True,
|
||||
is_square=True)
|
||||
|
||||
@property
|
||||
def _cov_initializer(self):
|
||||
return diagonal_covariance_initializer
|
||||
|
||||
@property
|
||||
def _matrix_diagonal(self):
|
||||
return array_ops.reshape(self.get_cov(), [-1])
|
||||
|
||||
def make_inverse_update_ops(self):
|
||||
return []
|
||||
|
||||
def instantiate_inv_variables(self):
|
||||
pass
|
||||
|
||||
def get_cov(self):
|
||||
# self.get_cov() could be any shape, but it must have one entry per
|
||||
# parameter. Flatten it into a vector.
|
||||
cov_diag_vec = array_ops.reshape(self.get_cov_var(), [-1])
|
||||
return array_ops.diag(cov_diag_vec)
|
||||
|
||||
def left_multiply_matpower(self, x, exp, damping_func):
|
||||
matpower = (self.get_cov_var() + damping_func())**exp
|
||||
|
||||
if isinstance(x, tf_ops.IndexedSlices):
|
||||
return utils.matmul_diag_sparse(array_ops.reshape(matpower, [-1]), x)
|
||||
|
||||
if x.shape != matpower.shape:
|
||||
raise ValueError("x (%s) and cov (%s) must have same shape." %
|
||||
(x, matpower))
|
||||
return matpower * x
|
||||
|
||||
def right_multiply_matpower(self, x, exp, damping_func):
|
||||
raise NotImplementedError("Only left-multiply is currently supported.")
|
||||
|
||||
def register_matpower(self, exp, damping_func):
|
||||
pass
|
||||
|
||||
def register_cholesky(self, damping_func):
|
||||
pass
|
||||
|
||||
def register_cholesky_inverse(self, damping_func):
|
||||
pass
|
||||
|
||||
def get_matpower(self, exp, damping_func):
|
||||
matpower_diagonal = (self._matrix_diagonal
|
||||
+ math_ops.cast(damping_func(), self._dtype))**exp
|
||||
return lo.LinearOperatorDiag(matpower_diagonal,
|
||||
is_non_singular=True,
|
||||
is_self_adjoint=True,
|
||||
is_positive_definite=True,
|
||||
is_square=True)
|
||||
|
||||
def get_cholesky(self, damping_func):
|
||||
return self.get_matpower(0.5, damping_func)
|
||||
|
||||
def get_cholesky_inverse(self, damping_func):
|
||||
return self.get_matpower(-0.5, damping_func)
|
||||
|
||||
|
||||
class NaiveDiagonalFactor(DiagonalFactor):
|
||||
"""FisherFactor for a diagonal approximation of any type of param's Fisher.
|
||||
@ -1167,7 +1237,7 @@ class ConvDiagonalFactor(DiagonalFactor):
|
||||
return self._inputs[tower].device
|
||||
|
||||
|
||||
class FullyConnectedKroneckerFactor(InverseProvidingFactor):
|
||||
class FullyConnectedKroneckerFactor(DenseSquareMatrixFactor):
|
||||
"""Kronecker factor for the input or output side of a fully-connected layer.
|
||||
"""
|
||||
|
||||
@ -1220,7 +1290,7 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
|
||||
return self._tensors[0][tower].device
|
||||
|
||||
|
||||
class ConvInputKroneckerFactor(InverseProvidingFactor):
|
||||
class ConvInputKroneckerFactor(DenseSquareMatrixFactor):
|
||||
r"""Kronecker factor for the input side of a convolutional layer.
|
||||
|
||||
Estimates E[ a a^T ] where a is the inputs to a convolutional layer given
|
||||
@ -1384,7 +1454,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
|
||||
return self._inputs[tower].device
|
||||
|
||||
|
||||
class ConvOutputKroneckerFactor(InverseProvidingFactor):
|
||||
class ConvOutputKroneckerFactor(DenseSquareMatrixFactor):
|
||||
r"""Kronecker factor for the output side of a convolutional layer.
|
||||
|
||||
Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer
|
||||
@ -1674,6 +1744,7 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
|
||||
psi_var) in self._option1quants_by_damping.items():
|
||||
|
||||
damping = self._damping_funcs_by_id[damping_id]()
|
||||
damping = math_ops.cast(damping, self._dtype)
|
||||
|
||||
invsqrtC0 = math_ops.matmul(
|
||||
eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True)
|
||||
@ -1702,6 +1773,7 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
|
||||
mu_var) in self._option2quants_by_damping.items():
|
||||
|
||||
damping = self._damping_funcs_by_id[damping_id]()
|
||||
damping = math_ops.cast(damping, self._dtype)
|
||||
|
||||
# compute C0^(-1/2)
|
||||
invsqrtC0 = math_ops.matmul(
|
||||
|
95
tensorflow/contrib/kfac/python/ops/linear_operator.py
Normal file
95
tensorflow/contrib/kfac/python/ops/linear_operator.py
Normal file
@ -0,0 +1,95 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""SmartMatrices definitions."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.kfac.python.ops import utils
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.linalg import linalg
|
||||
from tensorflow.python.ops.linalg import linalg_impl
|
||||
from tensorflow.python.ops.linalg import linear_operator_util as lou
|
||||
|
||||
|
||||
class LinearOperatorExtras(object): # pylint: disable=missing-docstring
|
||||
|
||||
def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
|
||||
|
||||
with self._name_scope(name, values=[x]):
|
||||
if isinstance(x, ops.IndexedSlices):
|
||||
return self._matmul_sparse(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
|
||||
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
self._check_input_dtype(x)
|
||||
|
||||
self_dim = -2 if adjoint else -1
|
||||
arg_dim = -1 if adjoint_arg else -2
|
||||
self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim])
|
||||
|
||||
return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
|
||||
|
||||
def matmul_right(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
|
||||
|
||||
with self._name_scope(name, values=[x]):
|
||||
|
||||
if isinstance(x, ops.IndexedSlices):
|
||||
return self._matmul_right_sparse(
|
||||
x, adjoint=adjoint, adjoint_arg=adjoint_arg)
|
||||
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
self._check_input_dtype(x)
|
||||
|
||||
self_dim = -1 if adjoint else -2
|
||||
arg_dim = -2 if adjoint_arg else -1
|
||||
self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim])
|
||||
|
||||
return self._matmul_right(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
|
||||
|
||||
|
||||
class LinearOperatorFullMatrix(LinearOperatorExtras,
|
||||
linalg.LinearOperatorFullMatrix):
|
||||
|
||||
# TODO(b/78117889) Remove this definition once core LinearOperator
|
||||
# has _matmul_right.
|
||||
def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
|
||||
return lou.matmul_with_broadcast(
|
||||
x, self._matrix, adjoint_a=adjoint_arg, adjoint_b=adjoint)
|
||||
|
||||
def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
|
||||
raise NotImplementedError
|
||||
|
||||
def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
|
||||
assert not adjoint and not adjoint_arg
|
||||
return utils.matmul_sparse_dense(x, self._matrix)
|
||||
|
||||
|
||||
class LinearOperatorDiag(LinearOperatorExtras, # pylint: disable=missing-docstring
|
||||
linalg.LinearOperatorDiag):
|
||||
|
||||
def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
|
||||
diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
|
||||
x = linalg_impl.adjoint(x) if adjoint_arg else x
|
||||
return diag_mat * x
|
||||
|
||||
def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
|
||||
diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
|
||||
assert not adjoint_arg
|
||||
return utils.matmul_diag_sparse(diag_mat, x)
|
||||
|
||||
def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
|
||||
raise NotImplementedError
|
@ -35,7 +35,7 @@ def _make_thunk_on_device(func, device):
|
||||
class RoundRobinPlacementMixin(object):
|
||||
"""Implements round robin placement strategy for ops and variables."""
|
||||
|
||||
def __init__(self, cov_devices=None, inv_devices=None, *args, **kwargs):
|
||||
def __init__(self, cov_devices=None, inv_devices=None, **kwargs):
|
||||
"""Initializes the RoundRobinPlacementMixin class.
|
||||
|
||||
Args:
|
||||
@ -45,11 +45,10 @@ class RoundRobinPlacementMixin(object):
|
||||
inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
|
||||
computations will be placed on these devices in a round-robin fashion.
|
||||
Can be None, which means that no devices are specified.
|
||||
*args:
|
||||
**kwargs:
|
||||
**kwargs: Need something here?
|
||||
|
||||
"""
|
||||
super(RoundRobinPlacementMixin, self).__init__(*args, **kwargs)
|
||||
super(RoundRobinPlacementMixin, self).__init__(**kwargs)
|
||||
self._cov_devices = cov_devices
|
||||
self._inv_devices = inv_devices
|
||||
|
||||
|
@ -235,6 +235,13 @@ posdef_eig_functions = {
|
||||
}
|
||||
|
||||
|
||||
def cholesky(tensor, damping):
|
||||
"""Computes the inverse of tensor + damping * identity."""
|
||||
identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)
|
||||
damping = math_ops.cast(damping, dtype=tensor.dtype)
|
||||
return linalg_ops.cholesky(tensor + damping * identity)
|
||||
|
||||
|
||||
class SubGraph(object):
|
||||
"""Defines a subgraph given by all the dependencies of a given set of outputs.
|
||||
"""
|
||||
@ -553,13 +560,17 @@ def is_data_format_channel_last(data_format):
|
||||
return data_format.endswith("C")
|
||||
|
||||
|
||||
def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name
|
||||
def matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=False): # pylint: disable=invalid-name
|
||||
"""Computes matmul(A, B) where A is sparse, B is dense.
|
||||
|
||||
Args:
|
||||
A: tf.IndexedSlices with dense shape [m, n].
|
||||
B: tf.Tensor with shape [n, k].
|
||||
name: str. Name of op.
|
||||
transpose_a: Bool. If true we transpose A before multiplying it by B.
|
||||
(Default: False)
|
||||
transpose_b: Bool. If true we transpose B before multiplying it by A.
|
||||
(Default: False)
|
||||
|
||||
Returns:
|
||||
tf.IndexedSlices resulting from matmul(A, B).
|
||||
@ -573,7 +584,8 @@ def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name
|
||||
raise ValueError("A must represent a matrix. Found: %s." % A)
|
||||
if B.shape.ndims != 2:
|
||||
raise ValueError("B must be a matrix.")
|
||||
new_values = math_ops.matmul(A.values, B)
|
||||
new_values = math_ops.matmul(
|
||||
A.values, B, transpose_a=transpose_a, transpose_b=transpose_b)
|
||||
return ops.IndexedSlices(
|
||||
new_values,
|
||||
A.indices,
|
||||
|
Loading…
x
Reference in New Issue
Block a user