diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index 97184dabb05..0626875b763 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -132,6 +132,7 @@ py_library(
         ":clip_weights",
         ":conditioning_utils",
         ":random_tensor_pool",
+        ":spectral_normalization",
         ":virtual_batchnorm",
         "//tensorflow/python:util",
     ],
@@ -676,3 +677,45 @@ py_test(
         "//third_party/py/numpy",
     ],
 )
+
+py_library(
+    name = "spectral_normalization",
+    srcs = [
+        "python/features/python/spectral_normalization.py",
+        "python/features/python/spectral_normalization_impl.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:init_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:standard_ops",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python/keras:engine",
+    ],
+)
+
+py_test(
+    name = "spectral_normalization_test",
+    srcs = ["python/features/python/spectral_normalization_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":spectral_normalization",
+        "//tensorflow/contrib/layers:layers_py",
+        "//tensorflow/contrib/slim",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:init_ops",
+        "//tensorflow/python:layers",
+        "//tensorflow/python:linalg_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/keras:layers",
+        "//third_party/py/numpy",
+    ],
+)
diff --git a/tensorflow/contrib/gan/python/features/__init__.py b/tensorflow/contrib/gan/python/features/__init__.py
index 4816daf7601..410c3a02052 100644
--- a/tensorflow/contrib/gan/python/features/__init__.py
+++ b/tensorflow/contrib/gan/python/features/__init__.py
@@ -27,11 +27,13 @@ from __future__ import print_function
 from tensorflow.contrib.gan.python.features.python import clip_weights
 from tensorflow.contrib.gan.python.features.python import conditioning_utils
 from tensorflow.contrib.gan.python.features.python import random_tensor_pool
+from tensorflow.contrib.gan.python.features.python import spectral_normalization
 from tensorflow.contrib.gan.python.features.python import virtual_batchnorm
 
 from tensorflow.contrib.gan.python.features.python.clip_weights import *
 from tensorflow.contrib.gan.python.features.python.conditioning_utils import *
 from tensorflow.contrib.gan.python.features.python.random_tensor_pool import *
+from tensorflow.contrib.gan.python.features.python.spectral_normalization import *
 from tensorflow.contrib.gan.python.features.python.virtual_batchnorm import *
 # pylint: enable=unused-import,wildcard-import
 
@@ -40,5 +42,6 @@ from tensorflow.python.util.all_util import remove_undocumented
 _allowed_symbols = clip_weights.__all__
 _allowed_symbols += conditioning_utils.__all__
 _allowed_symbols += random_tensor_pool.__all__
+_allowed_symbols += spectral_normalization.__all__
 _allowed_symbols += virtual_batchnorm.__all__
 remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/gan/python/features/python/spectral_normalization.py b/tensorflow/contrib/gan/python/features/python/spectral_normalization.py
new file mode 100644
index 00000000000..54d3d0a218d
--- /dev/null
+++ b/tensorflow/contrib/gan/python/features/python/spectral_normalization.py
@@ -0,0 +1,32 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras-like layers and utilities that implement Spectral Normalization.
+
+Based on "Spectral Normalization for Generative Adversarial Networks" by Miyato
+et al. in ICLR 2018. https://openreview.net/pdf?id=B1QRgziT-
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.gan.python.features.python import spectral_normalization_impl
+# pylint: disable=wildcard-import
+from tensorflow.contrib.gan.python.features.python.spectral_normalization_impl import *
+# pylint: enable=wildcard-import
+from tensorflow.python.util.all_util import remove_undocumented
+
+__all__ = spectral_normalization_impl.__all__
+remove_undocumented(__name__, __all__)
diff --git a/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py b/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py
new file mode 100644
index 00000000000..0cc653f0a79
--- /dev/null
+++ b/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py
@@ -0,0 +1,315 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras-like layers and utilities that implement Spectral Normalization.
+
+Based on "Spectral Normalization for Generative Adversarial Networks" by Miyato
+et al. in ICLR 2018.
+https://openreview.net/pdf?id=B1QRgziT-
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import numbers
+import re
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.engine import base_layer_utils as keras_base_layer_utils
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import tf_logging as logging
+
+__all__ = [
+    'compute_spectral_norm', 'spectral_normalize', 'spectral_norm_regularizer',
+    'spectral_normalization_custom_getter', 'keras_spectral_normalization'
+]
+
+# tf.bfloat16 should work, but tf.matmul converts those to tf.float32 which
+# then can't directly be assigned back to the tf.bfloat16 variable.
+_OK_DTYPES_FOR_SPECTRAL_NORM = (dtypes.float16, dtypes.float32, dtypes.float64)
+_PERSISTED_U_VARIABLE_SUFFIX = 'spectral_norm_u'
+
+
+def compute_spectral_norm(w_tensor, power_iteration_rounds=1, name=None):
+  """Estimates the largest singular value in the weight tensor.
+
+  Args:
+    w_tensor: The weight matrix whose spectral norm should be computed.
+    power_iteration_rounds: The number of iterations of the power method to
+      perform. A higher number yields a better approximation.
+    name: An optional scope name.
+
+  Returns:
+    The largest singular value (the spectral norm) of w.
+  """
+  with variable_scope.variable_scope(name, 'spectral_norm'):
+    # The paper says to flatten convnet kernel weights from
+    # (C_out, C_in, KH, KW) to (C_out, C_in * KH * KW). But TensorFlow's Conv2D
+    # kernel weight shape is (KH, KW, C_in, C_out), so it should be reshaped to
+    # (KH * KW * C_in, C_out), and similarly for other layers that put output
+    # channels as last dimension.
+    # n.b. this means that w here is equivalent to w.T in the paper.
+    w = array_ops.reshape(w_tensor, (-1, w_tensor.get_shape()[-1]))
+
+    # Persisted approximation of first left singular vector of matrix `w`.
+    u_var = variable_scope.get_variable(
+        _PERSISTED_U_VARIABLE_SUFFIX,
+        shape=(w.shape[0], 1),
+        dtype=w.dtype,
+        initializer=init_ops.random_normal_initializer(),
+        trainable=False)
+    u = u_var
+
+    # Use power iteration method to approximate spectral norm.
+    for _ in range(power_iteration_rounds):
+      # `v` approximates the first right singular vector of matrix `w`.
+      v = nn.l2_normalize(math_ops.matmul(array_ops.transpose(w), u))
+      u = nn.l2_normalize(math_ops.matmul(w, v))
+
+    # Update persisted approximation.
+    with ops.control_dependencies([u_var.assign(u, name='update_u')]):
+      u = array_ops.identity(u)
+
+    u = array_ops.stop_gradient(u)
+    v = array_ops.stop_gradient(v)
+
+    # Largest singular value of `w`.
+    spectral_norm = math_ops.matmul(
+        math_ops.matmul(array_ops.transpose(u), w), v)
+    spectral_norm.shape.assert_is_fully_defined()
+    spectral_norm.shape.assert_is_compatible_with([1, 1])
+
+    return spectral_norm[0][0]
+
+
+def spectral_normalize(w, power_iteration_rounds=1, name=None):
+  """Normalizes a weight matrix by its spectral norm.
+
+  Args:
+    w: The weight matrix to be normalized.
+    power_iteration_rounds: The number of iterations of the power method to
+      perform. A higher number yields a better approximation.
+    name: An optional scope name.
+
+  Returns:
+    A normalized weight matrix tensor.
+  """
+  with variable_scope.variable_scope(name, 'spectral_normalize'):
+    w_normalized = w / compute_spectral_norm(
+        w, power_iteration_rounds=power_iteration_rounds)
+    return array_ops.reshape(w_normalized, w.get_shape())
+
+
+def spectral_norm_regularizer(scale, power_iteration_rounds=1, scope=None):
+  """Returns a function that can be used to apply spectral norm regularization.
+
+  Small spectral norms enforce a small Lipschitz constant, which is necessary
+  for Wasserstein GANs.
+
+  Args:
+    scale: A scalar multiplier. 0.0 disables the regularizer.
+    power_iteration_rounds: The number of iterations of the power method to
+      perform. A higher number yields a better approximation.
+    scope: An optional scope name.
+
+  Returns:
+    A function with the signature `sn(weights)` that applies spectral norm
+    regularization.
+
+  Raises:
+    ValueError: If scale is negative or if scale is not a float.
+  """
+  if isinstance(scale, numbers.Integral):
+    raise ValueError('scale cannot be an integer: %s' % scale)
+  if isinstance(scale, numbers.Real):
+    if scale < 0.0:
+      raise ValueError(
+          'Setting a scale less than 0 on a regularizer: %g' % scale)
+    if scale == 0.0:
+      logging.info('Scale of 0 disables regularizer.')
+      return lambda _: None
+
+  def sn(weights, name=None):
+    """Applies spectral norm regularization to weights."""
+    with ops.name_scope(scope, 'SpectralNormRegularizer', [weights]) as name:
+      scale_t = ops.convert_to_tensor(
+          scale, dtype=weights.dtype.base_dtype, name='scale')
+      return math_ops.multiply(
+          scale_t,
+          compute_spectral_norm(
+              weights, power_iteration_rounds=power_iteration_rounds),
+          name=name)
+
+  return sn
+
+
+def _default_name_filter(name):
+  """A filter function to identify common names of weight variables.
+
+  Args:
+    name: The variable name.
+
+  Returns:
+    Whether `name` is a standard name for a weight/kernel variable used in the
+    Keras, tf.layers, tf.contrib.layers or tf.contrib.slim libraries.
+  """
+  match = re.match(r'(.*\/)?(depthwise_|pointwise_)?(weights|kernel)$', name)
+  return match is not None
+
+
+def spectral_normalization_custom_getter(name_filter=_default_name_filter,
+                                         power_iteration_rounds=1):
+  """Custom getter that performs Spectral Normalization on a weight tensor.
+
+  Specifically it divides the weight tensor by its largest singular value. This
+  is intended to stabilize GAN training, by making the discriminator satisfy a
+  local 1-Lipschitz constraint.
+
+  Based on [Spectral Normalization for Generative Adversarial Networks][sn-gan].
+
+  [sn-gan]: https://openreview.net/forum?id=B1QRgziT-
+
+  To reproduce an SN-GAN, apply this custom_getter to every weight tensor of
+  your discriminator. The last dimension of the weight tensor must be the
+  number of output channels.
+
+  Apply this to layers by supplying this as the `custom_getter` of a
+  `tf.variable_scope`. For example:
+
+    with tf.variable_scope(
+        'discriminator',
+        custom_getter=spectral_normalization_custom_getter()):
+      net = discriminator_fn(net)
+
+  IMPORTANT: Keras does not respect the custom_getter supplied by the
+  VariableScope, so Keras users should use `keras_spectral_normalization`
+  instead of (or in addition to) this approach.
+
+  It is important to carefully select the weights to which you want to apply
+  Spectral Normalization. In general you want to normalize the kernels of
+  convolution and dense layers, but you do not want to normalize biases.
+  You also want to avoid normalizing batch normalization (and similar)
+  variables; such layers play poorly with Spectral Normalization in general,
+  since their gamma can cancel out the normalization applied to other layers.
+  By default we supply a filter that matches the kernel variable names of the
+  dense and convolution layers of the tf.layers, tf.contrib.layers, tf.keras
+  and tf.contrib.slim libraries. If you are using anything else you'll need a
+  custom `name_filter`.
+
+  This custom getter internally creates a variable used to compute the spectral
+  norm by power iteration. It will update every time the variable is accessed,
+  which means the normalized discriminator weights may change slightly whilst
+  training the generator. Whilst unusual, this matches how the paper's authors
+  implement it, and in general additional rounds of power iteration can't hurt.
+
+  Args:
+    name_filter: Optionally, a method that takes a Variable name as input and
+      returns whether this Variable should be normalized.
+    power_iteration_rounds: The number of iterations of the power method to
+      perform per step. A higher number yields a better approximation of the
+      true spectral norm.
+
+  Returns:
+    A custom getter function that applies Spectral Normalization to all
+    Variables whose names match `name_filter`.
+
+  Raises:
+    ValueError: If name_filter is not callable.
+  """
+  if not callable(name_filter):
+    raise ValueError('name_filter must be callable')
+
+  def _internal_getter(getter, name, *args, **kwargs):
+    """A custom getter function that applies Spectral Normalization.
+
+    Args:
+      getter: The true getter to call.
+      name: Name of new/existing variable, in the same format as
+        tf.get_variable.
+      *args: Other positional arguments, in the same format as tf.get_variable.
+      **kwargs: Keyword arguments, in the same format as tf.get_variable.
+
+    Returns:
+      The return value of `getter(name, *args, **kwargs)`, spectrally
+      normalized.
+
+    Raises:
+      ValueError: If used incorrectly, or if `dtype` is not supported.
+    """
+    if not name_filter(name):
+      return getter(name, *args, **kwargs)
+
+    if name.endswith(_PERSISTED_U_VARIABLE_SUFFIX):
+      raise ValueError(
+          'Cannot apply Spectral Normalization to internal variables created '
+          'for Spectral Normalization. Tried to normalize variable [%s]' %
+          name)
+
+    if kwargs['dtype'] not in _OK_DTYPES_FOR_SPECTRAL_NORM:
+      raise ValueError('Disallowed data type {}'.format(kwargs['dtype']))
+
+    # This layer's weight Variable/PartitionedVariable.
+    w_tensor = getter(name, *args, **kwargs)
+
+    if len(w_tensor.get_shape()) < 2:
+      raise ValueError(
+          'Spectral norm can only be applied to multi-dimensional tensors')
+
+    return spectral_normalize(
+        w_tensor,
+        power_iteration_rounds=power_iteration_rounds,
+        name=(name + '/spectral_normalize'))
+
+  return _internal_getter
+
+
+@contextlib.contextmanager
+def keras_spectral_normalization(name_filter=_default_name_filter,
+                                 power_iteration_rounds=1):
+  """A context manager that enables Spectral Normalization for Keras.
+
+  Keras doesn't respect the `custom_getter` in the VariableScope, so this is a
+  bit of a hack to make things work.
+
+  Usage:
+
+    with keras_spectral_normalization():
+      net = discriminator_fn(net)
+
+  Args:
+    name_filter: Optionally, a method that takes a Variable name as input and
+      returns whether this Variable should be normalized.
+    power_iteration_rounds: The number of iterations of the power method to
+      perform per step. A higher number yields a better approximation of the
+      true spectral norm.
+
+  Yields:
+    A context manager that wraps the standard Keras variable creation method
+    with the `spectral_normalization_custom_getter`.
+  """
+  original_make_variable = keras_base_layer_utils.make_variable
+  sn_getter = spectral_normalization_custom_getter(
+      name_filter=name_filter, power_iteration_rounds=power_iteration_rounds)
+
+  def make_variable_wrapper(name, *args, **kwargs):
+    return sn_getter(original_make_variable, name, *args, **kwargs)
+
+  keras_base_layer_utils.make_variable = make_variable_wrapper
+
+  try:
+    yield
+  finally:
+    # Restore the original method even if the wrapped block raises, so the
+    # monkey-patch never leaks outside this context manager.
+    keras_base_layer_utils.make_variable = original_make_variable
diff --git a/tensorflow/contrib/gan/python/features/python/spectral_normalization_test.py b/tensorflow/contrib/gan/python/features/python/spectral_normalization_test.py
new file mode 100644
index 00000000000..4ea21f70ec0
--- /dev/null
+++ b/tensorflow/contrib/gan/python/features/python/spectral_normalization_test.py
@@ -0,0 +1,354 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for features.spectral_normalization."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib import slim
+from tensorflow.contrib.gan.python.features.python import spectral_normalization_impl as spectral_normalization
+from tensorflow.contrib.layers.python.layers import layers as contrib_layers
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.layers import convolutional as keras_convolutional
+from tensorflow.python.keras.layers import core as keras_core
+from tensorflow.python.layers import convolutional as layers_convolutional
+from tensorflow.python.layers import core as layers_core
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class SpectralNormalizationTest(test.TestCase):
+
+  def testComputeSpectralNorm(self):
+    weights = variable_scope.get_variable(
+        'w', dtype=dtypes.float32, shape=[2, 3, 50, 100])
+    weights = math_ops.multiply(weights, 10.0)
+    s = linalg_ops.svd(
+        array_ops.reshape(weights, [-1, weights.shape[-1]]), compute_uv=False)
+    true_sn = s[..., 0]
+    estimated_sn = spectral_normalization.compute_spectral_norm(weights)
+
+    with self.cached_session() as sess:
+      sess.run(variables.global_variables_initializer())
+      np_true_sn = sess.run(true_sn)
+      for i in range(50):
+        est = sess.run(estimated_sn)
+        if i < 1:
+          np_est_1 = est
+        if i < 5:
+          np_est_5 = est
+        if i < 10:
+          np_est_10 = est
+        np_est_50 = est
+
+      # Check that the estimate improves with more iterations.
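+      # Each sess.run above performs one more round of power iteration and,
+      # via the control dependency in compute_spectral_norm, updates the
+      # persisted `u` variable, so later estimates should lie closer to the
+      # true largest singular value computed by SVD.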
+      self.assertAlmostEqual(np_true_sn, np_est_50, 0)
+      self.assertGreater(
+          abs(np_true_sn - np_est_10), abs(np_true_sn - np_est_50))
+      self.assertGreater(
+          abs(np_true_sn - np_est_5), abs(np_true_sn - np_est_10))
+      self.assertGreater(abs(np_true_sn - np_est_1), abs(np_true_sn - np_est_5))
+
+  def testSpectralNormalize(self):
+    weights = variable_scope.get_variable(
+        'w', dtype=dtypes.float32, shape=[2, 3, 50, 100])
+    weights = math_ops.multiply(weights, 10.0)
+    normalized_weights = spectral_normalization.spectral_normalize(
+        weights, power_iteration_rounds=1)
+
+    unnormalized_sigma = linalg_ops.svd(
+        array_ops.reshape(weights, [-1, weights.shape[-1]]),
+        compute_uv=False)[..., 0]
+    normalized_sigma = linalg_ops.svd(
+        array_ops.reshape(normalized_weights, [-1, weights.shape[-1]]),
+        compute_uv=False)[..., 0]
+
+    with self.cached_session() as sess:
+      sess.run(variables.global_variables_initializer())
+      s0 = sess.run(unnormalized_sigma)
+
+      for i in range(50):
+        sigma = sess.run(normalized_sigma)
+        if i < 1:
+          s1 = sigma
+        if i < 5:
+          s5 = sigma
+        if i < 10:
+          s10 = sigma
+        s50 = sigma
+
+      self.assertAlmostEqual(1., s50, 0)
+      self.assertGreater(abs(s10 - 1.), abs(s50 - 1.))
+      self.assertGreater(abs(s5 - 1.), abs(s10 - 1.))
+      self.assertGreater(abs(s1 - 1.), abs(s5 - 1.))
+      self.assertGreater(abs(s0 - 1.), abs(s1 - 1.))
+
+  def _testLayerHelper(self, build_layer_fn, w_shape, b_shape, is_keras=False):
+    x = array_ops.placeholder(dtypes.float32, shape=[2, 10, 10, 3])
+
+    w_initial = np.random.randn(*w_shape) * 10
+    w_initializer = init_ops.constant_initializer(w_initial)
+    b_initial = np.random.randn(*b_shape)
+    b_initializer = init_ops.constant_initializer(b_initial)
+
+    if is_keras:
+      context_manager = spectral_normalization.keras_spectral_normalization()
+    else:
+      getter = spectral_normalization.spectral_normalization_custom_getter()
+      context_manager = variable_scope.variable_scope('', custom_getter=getter)
+
+    with context_manager:
+      (net, expected_normalized_vars,
+       expected_not_normalized_vars) = build_layer_fn(x, w_initializer,
+                                                      b_initializer)
+
+    x_data = np.random.rand(*x.shape)
+
+    with self.cached_session() as sess:
+      sess.run(variables.global_variables_initializer())
+
+      # Before running a forward pass we still expect the variables' values to
+      # differ from the initial value because of the normalizer.
+      w_befores = []
+      for name, var in expected_normalized_vars.items():
+        w_before = sess.run(var)
+        w_befores.append(w_before)
+        self.assertFalse(
+            np.allclose(w_initial, w_before),
+            msg=('%s appears not to be normalized. Before: %s After: %s' %
+                 (name, w_initial, w_before)))
+
+      # Not true for the unnormalized variables.
+      for name, var in expected_not_normalized_vars.items():
+        b_before = sess.run(var)
+        self.assertTrue(
+            np.allclose(b_initial, b_before),
+            msg=('%s appears to be unexpectedly normalized. '
+                 'Before: %s After: %s' % (name, b_initial, b_before)))
+
+      # Run a bunch of forward passes.
+      for _ in range(1000):
+        _ = sess.run(net, feed_dict={x: x_data})
+
+      # We expect this to have improved the estimate of the spectral norm,
+      # which should have changed the variable values and brought them close
+      # to the true Spectral Normalized values.
+      _, s, _ = np.linalg.svd(w_initial.reshape([-1, 3]))
+      exactly_normalized = w_initial / s[0]
+      for w_before, (name, var) in zip(w_befores,
+                                       expected_normalized_vars.items()):
+        w_after = sess.run(var)
+        self.assertFalse(
+            np.allclose(w_before, w_after, rtol=1e-8, atol=1e-8),
+            msg=('%s did not change over many iterations. '
+                 'Before: %s After: %s' % (name, w_before, w_after)))
+        self.assertAllClose(
+            exactly_normalized,
+            w_after,
+            rtol=1e-4,
+            atol=1e-4,
+            msg=('Estimate of spectral norm for %s was inaccurate. '
+                 'Normalized matrices do not match. '
+                 'Estimate: %s Actual: %s' % (name, w_after,
+                                              exactly_normalized)))
+
+  def testConv2D_Layers(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      layer = layers_convolutional.Conv2D(
+          filters=3,
+          kernel_size=3,
+          padding='same',
+          kernel_initializer=w_initializer,
+          bias_initializer=b_initializer)
+      net = layer.apply(x)
+      expected_normalized_vars = {'tf.layers.Conv2d.kernel': layer.kernel}
+      expected_not_normalized_vars = {'tf.layers.Conv2d.bias': layer.bias}
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,))
+
+  def testConv2D_ContribLayers(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      var_collection = {
+          'weights': ['CONTRIB_LAYERS_CONV2D_WEIGHTS'],
+          'biases': ['CONTRIB_LAYERS_CONV2D_BIASES']
+      }
+      net = contrib_layers.conv2d(
+          x,
+          3,
+          3,
+          weights_initializer=w_initializer,
+          biases_initializer=b_initializer,
+          variables_collections=var_collection)
+      weight_vars = ops.get_collection('CONTRIB_LAYERS_CONV2D_WEIGHTS')
+      self.assertEqual(1, len(weight_vars))
+      bias_vars = ops.get_collection('CONTRIB_LAYERS_CONV2D_BIASES')
+      self.assertEqual(1, len(bias_vars))
+      expected_normalized_vars = {
+          'contrib.layers.conv2d.weights': weight_vars[0]
+      }
+      expected_not_normalized_vars = {
+          'contrib.layers.conv2d.bias': bias_vars[0]
+      }
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,))
+
+  def testConv2D_Slim(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      var_collection = {
+          'weights': ['SLIM_CONV2D_WEIGHTS'],
+          'biases': ['SLIM_CONV2D_BIASES']
+      }
+      net = slim.conv2d(
+          x,
+          3,
+          3,
+          weights_initializer=w_initializer,
+          biases_initializer=b_initializer,
+          variables_collections=var_collection)
+      weight_vars = ops.get_collection('SLIM_CONV2D_WEIGHTS')
+      self.assertEqual(1, len(weight_vars))
+      bias_vars = ops.get_collection('SLIM_CONV2D_BIASES')
+      self.assertEqual(1, len(bias_vars))
+      expected_normalized_vars = {'slim.conv2d.weights': weight_vars[0]}
+      expected_not_normalized_vars = {'slim.conv2d.bias': bias_vars[0]}
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,))
+
+  def testConv2D_Keras(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      layer = keras_convolutional.Conv2D(
+          filters=3,
+          kernel_size=3,
+          padding='same',
+          kernel_initializer=w_initializer,
+          bias_initializer=b_initializer)
+      net = layer.apply(x)
+      expected_normalized_vars = {'keras.layers.Conv2d.kernel': layer.kernel}
+      expected_not_normalized_vars = {'keras.layers.Conv2d.bias': layer.bias}
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,), is_keras=True)
+
+  def testFC_Layers(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      x = layers_core.Flatten()(x)
+      layer = layers_core.Dense(
+          units=3,
+          kernel_initializer=w_initializer,
+          bias_initializer=b_initializer)
+      net = layer.apply(x)
+      expected_normalized_vars = {'tf.layers.Dense.kernel': layer.kernel}
+      expected_not_normalized_vars = {'tf.layers.Dense.bias': layer.bias}
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (300, 3), (3,))
+
+  def testFC_ContribLayers(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      var_collection = {
+          'weights': ['CONTRIB_LAYERS_FC_WEIGHTS'],
+          'biases': ['CONTRIB_LAYERS_FC_BIASES']
+      }
+      x = contrib_layers.flatten(x)
+      net = contrib_layers.fully_connected(
+          x,
+          3,
+          weights_initializer=w_initializer,
+          biases_initializer=b_initializer,
+          variables_collections=var_collection)
+      weight_vars = ops.get_collection('CONTRIB_LAYERS_FC_WEIGHTS')
+      self.assertEqual(1, len(weight_vars))
+      bias_vars = ops.get_collection('CONTRIB_LAYERS_FC_BIASES')
+      self.assertEqual(1, len(bias_vars))
+      expected_normalized_vars = {
+          'contrib.layers.fully_connected.weights': weight_vars[0]
+      }
+      expected_not_normalized_vars = {
+          'contrib.layers.fully_connected.bias': bias_vars[0]
+      }
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (300, 3), (3,))
+
+  def testFC_Slim(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      var_collection = {
+          'weights': ['SLIM_FC_WEIGHTS'],
+          'biases': ['SLIM_FC_BIASES']
+      }
+      x = slim.flatten(x)
+      net = slim.fully_connected(
+          x,
+          3,
+          weights_initializer=w_initializer,
+          biases_initializer=b_initializer,
+          variables_collections=var_collection)
+      weight_vars = ops.get_collection('SLIM_FC_WEIGHTS')
+      self.assertEqual(1, len(weight_vars))
+      bias_vars = ops.get_collection('SLIM_FC_BIASES')
+      self.assertEqual(1, len(bias_vars))
+      expected_normalized_vars = {
+          'slim.fully_connected.weights': weight_vars[0]
+      }
+      expected_not_normalized_vars = {'slim.fully_connected.bias': bias_vars[0]}
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (300, 3), (3,))
+
+  def testFC_Keras(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      x = keras_core.Flatten()(x)
+      layer = keras_core.Dense(
+          units=3,
+          kernel_initializer=w_initializer,
+          bias_initializer=b_initializer)
+      net = layer.apply(x)
+      expected_normalized_vars = {'keras.layers.Dense.kernel': layer.kernel}
+      expected_not_normalized_vars = {'keras.layers.Dense.bias': layer.bias}
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (300, 3), (3,), is_keras=True)
+
+
+if __name__ == '__main__':
+  test.main()
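
Usage sketch (not part of the diff): a minimal example of how the two entry
points added above can be applied to a discriminator, assuming the public
re-export through `tf.contrib.gan.features` wired up in `__init__.py` above.
`discriminator_fn` and its layers are hypothetical placeholders, not part of
this change.

    import tensorflow as tf

    tfgan = tf.contrib.gan

    def discriminator_fn(net):
      # Kernel variables created here ('.../conv2d/kernel',
      # '.../dense/kernel') match the default name filter and are spectrally
      # normalized; bias variables do not match and are left untouched.
      net = tf.layers.conv2d(net, filters=64, kernel_size=4, strides=2,
                             activation=tf.nn.leaky_relu)
      net = tf.layers.flatten(net)
      return tf.layers.dense(net, 1)

    images = tf.random_normal([16, 32, 32, 3])

    # For tf.layers / tf.contrib.layers / slim discriminators, supply the
    # custom getter through a variable scope.
    with tf.variable_scope(
        'discriminator',
        custom_getter=tfgan.features.spectral_normalization_custom_getter()):
      logits = discriminator_fn(images)

    # Keras layers ignore the variable-scope custom getter; a Keras-built
    # discriminator needs the context manager instead:
    #
    #   with tfgan.features.keras_spectral_normalization():
    #     logits = keras_discriminator_fn(images)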