diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py index 1da811dc0a3..432937d8032 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py @@ -282,6 +282,73 @@ class LayerCollectionTest(test.TestCase): single_loss = sess.run(lc.total_loss()) self.assertAlmostEqual(7.6983433, single_loss) + def testRegisterFullyConnectedReuse(self): + """Ensure the 'reuse' keyword argument function as intended.""" + with ops.Graph().as_default(): + inputs = [ + array_ops.ones([2, 10]), # + array_ops.zeros([5, 10]) + ] + outputs = [ + array_ops.zeros([2, 5]), # + array_ops.ones([5, 5]) + ] + params = ( + variable_scope.get_variable('w', [10, 5]), # + variable_scope.get_variable('b', [5])) + + # Fails on second if reuse=False. + lc = layer_collection.LayerCollection() + lc.register_fully_connected(params, inputs[0], outputs[0]) + with self.assertRaises(ValueError): + lc.register_fully_connected(params, inputs[1], outputs[1], reuse=False) + + # Succeeds on second if reuse=True. + lc = layer_collection.LayerCollection() + lc.register_fully_connected(params, inputs[0], outputs[0]) + lc.register_fully_connected(params, inputs[1], outputs[1], reuse=True) + + # Fails on second if reuse=VARIABLE_SCOPE and no variable reuse. + lc = layer_collection.LayerCollection() + lc.register_fully_connected(params, inputs[0], outputs[0]) + with self.assertRaises(ValueError): + lc.register_fully_connected( + params, + inputs[1], + outputs[1], + reuse=layer_collection.VARIABLE_SCOPE) + + # Succeeds on second if reuse=VARIABLE_SCOPE and variable reuse. + lc = layer_collection.LayerCollection() + lc.register_fully_connected(params, inputs[0], outputs[0]) + with variable_scope.variable_scope( + variable_scope.get_variable_scope(), reuse=True): + lc.register_fully_connected( + params, + inputs[1], + outputs[1], + reuse=layer_collection.VARIABLE_SCOPE) + + # Fails if block type changes. + lc = layer_collection.LayerCollection() + lc.register_fully_connected( + params, + inputs[0], + outputs[0], + approx=layer_collection.APPROX_KRONECKER_NAME) + with self.assertRaises(ValueError): + lc.register_fully_connected( + params, + inputs[1], + outputs[1], + approx=layer_collection.APPROX_DIAGONAL_NAME, + reuse=True) + + # Fails if reuse requested but no FisherBlock exists. + lc = layer_collection.LayerCollection() + with self.assertRaises(KeyError): + lc.register_fully_connected(params, inputs[0], outputs[0], reuse=True) + def testMakeOrGetFactor(self): with ops.Graph().as_default(): random_seed.set_random_seed(200) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index 49279954dc8..cd711d05610 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -39,10 +39,15 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +# Names for various approximations that can be requested for Fisher blocks. APPROX_KRONECKER_NAME = "kron" APPROX_DIAGONAL_NAME = "diagonal" APPROX_FULL_NAME = "full" +# Possible value for 'reuse' keyword argument. Sets 'reuse' to +# tf.get_variable_scope().reuse. +VARIABLE_SCOPE = "VARIABLE_SCOPE" + # TODO(jamesmartens): need to add find_canonical_output back into this somewhere @@ -254,19 +259,58 @@ class LayerCollection(object): params, inputs, outputs, - approx=APPROX_KRONECKER_NAME): - has_bias = isinstance(params, (tuple, list)) - if approx == APPROX_KRONECKER_NAME: - block = fb.FullyConnectedKFACBasicFB(self, has_bias) - block.register_additional_minibatch(inputs, outputs) - self.register_block(params, block) - elif approx == APPROX_DIAGONAL_NAME: - block = fb.FullyConnectedDiagonalFB(self, has_bias) - block.register_additional_minibatch(inputs, outputs) - self.register_block(params, block) - else: + approx=APPROX_KRONECKER_NAME, + reuse=VARIABLE_SCOPE): + """Registers a fully connnected layer. + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [input_size, output_size]. + Bias should have shape [output_size]. + inputs: Tensor of shape [batch_size, input_size]. Inputs to layer. + outputs: Tensor of shape [batch_size, output_size]. Preactivations + produced by layer. + approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME. + reuse: bool or str. If True, reuse an existing FisherBlock. If False, + create a new FisherBlock. If VARIABLE_SCOPE, use + tf.get_variable_scope().reuse. + + Raises: + ValueError: For improper value to 'approx'. + KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + approx_to_block_types = { + APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB, + APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB, + } + + if approx not in approx_to_block_types: raise ValueError("Bad value {} for approx.".format(approx)) + block_type = approx_to_block_types[approx] + has_bias = isinstance(params, (tuple, list)) + + if reuse == VARIABLE_SCOPE: + reuse = variable_scope.get_variable_scope().reuse + + if reuse: + block = self.fisher_blocks.get(params, None) + if block is None: + raise KeyError( + "Reuse requested but no FisherBlock found for params {}.".format( + params)) + if not isinstance(block, block_type): + raise ValueError( + "Requested block of type {} but block of type {} already exists " + "for params {}.".format(block_type, type(block), params)) + + else: + block = block_type(self, has_bias) + self.register_block(params, block) + + block.register_additional_minibatch(inputs, outputs) + def register_conv2d(self, params, strides, padding, inputs, outputs, approx=APPROX_KRONECKER_NAME): diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py index 63a9b173bc8..d6bf61a2102 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py @@ -35,6 +35,7 @@ _allowed_symbols = [ "APPROX_KRONECKER_NAME", "APPROX_DIAGONAL_NAME", "APPROX_FULL_NAME", + "VARIABLE_SCOPE", ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)