Spectral Normalization implementation in TFGAN
PiperOrigin-RevId: 227725854
parent 65011487c4
commit e518527f10
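
In brief, this change adds Spectral Normalization utilities to TFGAN, based on "Spectral Normalization for Generative Adversarial Networks" (Miyato et al., ICLR 2018). A minimal usage sketch of the headline API, assuming a TF 1.x build that includes this change (the discriminator below is an illustrative placeholder, not code from this commit):

import tensorflow as tf
from tensorflow.contrib.gan.python.features.python import spectral_normalization

def discriminator_fn(net):
  # Illustrative discriminator; any model built from tf.layers works, since
  # the default name filter matches tf.layers kernel variable names.
  net = tf.layers.conv2d(net, filters=64, kernel_size=4, strides=2)
  return tf.layers.dense(tf.layers.flatten(net), 1)

images = tf.placeholder(tf.float32, [None, 28, 28, 1])
# Every variable whose name matches the default filter (e.g. '.../kernel')
# is divided by its spectral norm when fetched through the custom getter;
# bias variables are left untouched.
with tf.variable_scope(
    'discriminator',
    custom_getter=spectral_normalization.spectral_normalization_custom_getter()):
  logits = discriminator_fn(images)
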
tensorflow/contrib/gan/BUILD
@@ -132,6 +132,7 @@ py_library(
         ":clip_weights",
         ":conditioning_utils",
         ":random_tensor_pool",
+        ":spectral_normalization",
         ":virtual_batchnorm",
         "//tensorflow/python:util",
     ],
@@ -676,3 +677,45 @@ py_test(
         "//third_party/py/numpy",
     ],
 )
+
+py_library(
+    name = "spectral_normalization",
+    srcs = [
+        "python/features/python/spectral_normalization.py",
+        "python/features/python/spectral_normalization_impl.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:init_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:standard_ops",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python/keras:engine",
+    ],
+)
+
+py_test(
+    name = "spectral_normalization_test",
+    srcs = ["python/features/python/spectral_normalization_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":spectral_normalization",
+        "//tensorflow/contrib/layers:layers_py",
+        "//tensorflow/contrib/slim",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:init_ops",
+        "//tensorflow/python:layers",
+        "//tensorflow/python:linalg_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/keras:layers",
+        "//third_party/py/numpy",
+    ],
+)
tensorflow/contrib/gan/python/features/__init__.py
@@ -27,11 +27,13 @@ from __future__ import print_function
 from tensorflow.contrib.gan.python.features.python import clip_weights
 from tensorflow.contrib.gan.python.features.python import conditioning_utils
 from tensorflow.contrib.gan.python.features.python import random_tensor_pool
+from tensorflow.contrib.gan.python.features.python import spectral_normalization
 from tensorflow.contrib.gan.python.features.python import virtual_batchnorm
 
 from tensorflow.contrib.gan.python.features.python.clip_weights import *
 from tensorflow.contrib.gan.python.features.python.conditioning_utils import *
 from tensorflow.contrib.gan.python.features.python.random_tensor_pool import *
+from tensorflow.contrib.gan.python.features.python.spectral_normalization import *
 from tensorflow.contrib.gan.python.features.python.virtual_batchnorm import *
 # pylint: enable=unused-import,wildcard-import
 
@@ -40,5 +42,6 @@ from tensorflow.python.util.all_util import remove_undocumented
 _allowed_symbols = clip_weights.__all__
 _allowed_symbols += conditioning_utils.__all__
 _allowed_symbols += random_tensor_pool.__all__
+_allowed_symbols += spectral_normalization.__all__
 _allowed_symbols += virtual_batchnorm.__all__
 remove_undocumented(__name__, _allowed_symbols)
tensorflow/contrib/gan/python/features/python/spectral_normalization.py
@@ -0,0 +1,32 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-like layers and utilities that implement Spectral Normalization.

Based on "Spectral Normalization for Generative Adversarial Networks" by Miyato
et al. in ICLR 2018. https://openreview.net/pdf?id=B1QRgziT-
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.gan.python.features.python import spectral_normalization_impl
# pylint: disable=wildcard-import
from tensorflow.contrib.gan.python.features.python.spectral_normalization_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented

__all__ = spectral_normalization_impl.__all__
remove_undocumented(__name__, __all__)
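
The module above simply re-exports the implementation's public API. Of the exported symbols, `spectral_norm_regularizer` follows the usual TF 1.x regularizer contract; a minimal sketch (the layer wiring below is illustrative, not code from this commit):

import tensorflow as tf
from tensorflow.contrib.gan.python.features.python import spectral_normalization

# The returned function computes scale * spectral_norm(weights), so large
# spectral norms are penalized rather than hard-normalized away.
sn_reg = spectral_normalization.spectral_norm_regularizer(scale=0.01)
x = tf.placeholder(tf.float32, [None, 64])
net = tf.layers.dense(x, 32, kernel_regularizer=sn_reg)
# The penalty lands in the standard regularization-loss collection.
reg_loss = tf.losses.get_regularization_loss()
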
tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py
@@ -0,0 +1,315 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-like layers and utilities that implement Spectral Normalization.

Based on "Spectral Normalization for Generative Adversarial Networks" by Miyato
et al. in ICLR 2018. https://openreview.net/pdf?id=B1QRgziT-
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import numbers
import re

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras.engine import base_layer_utils as keras_base_layer_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging

__all__ = [
    'compute_spectral_norm', 'spectral_normalize', 'spectral_norm_regularizer',
    'spectral_normalization_custom_getter', 'keras_spectral_normalization'
]

# tf.bfloat16 should work, but tf.matmul converts those to tf.float32 which
# then can't directly be assigned back to the tf.bfloat16 variable.
_OK_DTYPES_FOR_SPECTRAL_NORM = (dtypes.float16, dtypes.float32, dtypes.float64)
_PERSISTED_U_VARIABLE_SUFFIX = 'spectral_norm_u'


def compute_spectral_norm(w_tensor, power_iteration_rounds=1, name=None):
  """Estimates the largest singular value in the weight tensor.

  Args:
    w_tensor: The weight matrix whose spectral norm should be computed.
    power_iteration_rounds: The number of iterations of the power method to
      perform. A higher number yields a better approximation.
    name: An optional scope name.

  Returns:
    The largest singular value (the spectral norm) of w.
  """
  with variable_scope.variable_scope(name, 'spectral_norm'):
    # The paper says to flatten convnet kernel weights from
    # (C_out, C_in, KH, KW) to (C_out, C_in * KH * KW). But TensorFlow's Conv2D
    # kernel weight shape is (KH, KW, C_in, C_out), so it should be reshaped to
    # (KH * KW * C_in, C_out), and similarly for other layers that put output
    # channels as last dimension.
    # n.b. this means that w here is equivalent to w.T in the paper.
    w = array_ops.reshape(w_tensor, (-1, w_tensor.get_shape()[-1]))

    # Persisted approximation of first left singular vector of matrix `w`.
    u_var = variable_scope.get_variable(
        _PERSISTED_U_VARIABLE_SUFFIX,
        shape=(w.shape[0], 1),
        dtype=w.dtype,
        initializer=init_ops.random_normal_initializer(),
        trainable=False)
    u = u_var

    # Use power iteration method to approximate the spectral norm.
    for _ in range(power_iteration_rounds):
      # `v` approximates the first right singular vector of matrix `w`.
      v = nn.l2_normalize(math_ops.matmul(array_ops.transpose(w), u))
      u = nn.l2_normalize(math_ops.matmul(w, v))

    # Update persisted approximation.
    with ops.control_dependencies([u_var.assign(u, name='update_u')]):
      u = array_ops.identity(u)

    u = array_ops.stop_gradient(u)
    v = array_ops.stop_gradient(v)

    # Largest singular value of `w`.
    spectral_norm = math_ops.matmul(
        math_ops.matmul(array_ops.transpose(u), w), v)
    spectral_norm.shape.assert_is_fully_defined()
    spectral_norm.shape.assert_is_compatible_with([1, 1])

    return spectral_norm[0][0]


def spectral_normalize(w, power_iteration_rounds=1, name=None):
  """Normalizes a weight matrix by its spectral norm.

  Args:
    w: The weight matrix to be normalized.
    power_iteration_rounds: The number of iterations of the power method to
      perform. A higher number yields a better approximation.
    name: An optional scope name.

  Returns:
    A normalized weight matrix tensor.
  """
  with variable_scope.variable_scope(name, 'spectral_normalize'):
    w_normalized = w / compute_spectral_norm(
        w, power_iteration_rounds=power_iteration_rounds)
    return array_ops.reshape(w_normalized, w.get_shape())


def spectral_norm_regularizer(scale, power_iteration_rounds=1, scope=None):
  """Returns a function that can be used to apply spectral norm regularization.

  Small spectral norms enforce a small Lipschitz constant, which is necessary
  for Wasserstein GANs.

  Args:
    scale: A scalar multiplier. 0.0 disables the regularizer.
    power_iteration_rounds: The number of iterations of the power method to
      perform. A higher number yields a better approximation.
    scope: An optional scope name.

  Returns:
    A function with the signature `sn(weights)` that applies spectral norm
    regularization.

  Raises:
    ValueError: If scale is negative or if scale is not a float.
  """
  if isinstance(scale, numbers.Integral):
    raise ValueError('scale cannot be an integer: %s' % scale)
  if isinstance(scale, numbers.Real):
    if scale < 0.0:
      raise ValueError(
          'Setting a scale less than 0 on a regularizer: %g' % scale)
    if scale == 0.0:
      logging.info('Scale of 0 disables regularizer.')
      return lambda _: None

  def sn(weights, name=None):
    """Applies spectral norm regularization to weights."""
    with ops.name_scope(scope, 'SpectralNormRegularizer', [weights]) as name:
      scale_t = ops.convert_to_tensor(
          scale, dtype=weights.dtype.base_dtype, name='scale')
      return math_ops.multiply(
          scale_t,
          compute_spectral_norm(
              weights, power_iteration_rounds=power_iteration_rounds),
          name=name)

  return sn


def _default_name_filter(name):
  """A filter function to identify common names of weight variables.

  Args:
    name: The variable name.

  Returns:
    Whether `name` is a standard name for a weight/kernel variable used in the
    Keras, tf.layers, tf.contrib.layers or tf.contrib.slim libraries.
  """
  match = re.match(r'(.*\/)?(depthwise_|pointwise_)?(weights|kernel)$', name)
  return match is not None


def spectral_normalization_custom_getter(name_filter=_default_name_filter,
                                         power_iteration_rounds=1):
  """Custom getter that performs Spectral Normalization on a weight tensor.

  Specifically it divides the weight tensor by its largest singular value. This
  is intended to stabilize GAN training, by making the discriminator satisfy a
  local 1-Lipschitz constraint.

  Based on [Spectral Normalization for Generative Adversarial Networks][sn-gan].

  [sn-gan]: https://openreview.net/forum?id=B1QRgziT-

  To reproduce an SN-GAN, apply this custom_getter to every weight tensor of
  your discriminator. The last dimension of the weight tensor must be the
  number of output channels.

  Apply this to layers by supplying this as the `custom_getter` of a
  `tf.variable_scope`. For example:

    with tf.variable_scope(
        'discriminator',
        custom_getter=spectral_normalization_custom_getter()):
      net = discriminator_fn(net)

  IMPORTANT: Keras does not respect the custom_getter supplied by the
  VariableScope, so Keras users should use `keras_spectral_normalization`
  instead of (or in addition to) this approach.

  It is important to carefully select which weights you apply Spectral
  Normalization to. In general you want to normalize the kernels of
  convolution and dense layers, but you do not want to normalize biases. You
  also want to avoid normalizing batch normalization (and similar) variables;
  in general such layers play poorly with Spectral Normalization anyway, since
  the gamma can cancel out the normalization in other layers. By default we
  supply a filter that matches the kernel variable names of the dense and
  convolution layers of the tf.layers, tf.contrib.layers, tf.keras and
  tf.contrib.slim libraries. If you are using anything else you'll need a
  custom `name_filter`.

  This custom getter internally creates a variable used to compute the spectral
  norm by power iteration. It will update every time the variable is accessed,
  which means the normalized discriminator weights may change slightly whilst
  training the generator. Whilst unusual, this matches how the paper's authors
  implement it, and in general additional rounds of power iteration can't hurt.

  Args:
    name_filter: Optionally, a method that takes a Variable name as input and
      returns whether this Variable should be normalized.
    power_iteration_rounds: The number of iterations of the power method to
      perform per step. A higher number yields a better approximation of the
      true spectral norm.

  Returns:
    A custom getter function that applies Spectral Normalization to all
    Variables whose names match `name_filter`.

  Raises:
    ValueError: If name_filter is not callable.
  """
  if not callable(name_filter):
    raise ValueError('name_filter must be callable')

  def _internal_getter(getter, name, *args, **kwargs):
    """A custom getter function that applies Spectral Normalization.

    Args:
      getter: The true getter to call.
      name: Name of new/existing variable, in the same format as
        tf.get_variable.
      *args: Other positional arguments, in the same format as tf.get_variable.
      **kwargs: Keyword arguments, in the same format as tf.get_variable.

    Returns:
      The return value of `getter(name, *args, **kwargs)`, spectrally
      normalized.

    Raises:
      ValueError: If used incorrectly, or if `dtype` is not supported.
    """
    if not name_filter(name):
      return getter(name, *args, **kwargs)

    if name.endswith(_PERSISTED_U_VARIABLE_SUFFIX):
      raise ValueError(
          'Cannot apply Spectral Normalization to internal variables created '
          'for Spectral Normalization. Tried to normalize variable [%s]' %
          name)

    if kwargs['dtype'] not in _OK_DTYPES_FOR_SPECTRAL_NORM:
      raise ValueError('Disallowed data type {}'.format(kwargs['dtype']))

    # This layer's weight Variable/PartitionedVariable.
    w_tensor = getter(name, *args, **kwargs)

    if len(w_tensor.get_shape()) < 2:
      raise ValueError(
          'Spectral norm can only be applied to multi-dimensional tensors')

    return spectral_normalize(
        w_tensor,
        power_iteration_rounds=power_iteration_rounds,
        name=(name + '/spectral_normalize'))

  return _internal_getter


@contextlib.contextmanager
def keras_spectral_normalization(name_filter=_default_name_filter,
                                 power_iteration_rounds=1):
  """A context manager that enables Spectral Normalization for Keras.

  Keras doesn't respect the `custom_getter` in the VariableScope, so this is a
  bit of a hack to make things work.

  Usage:
    with keras_spectral_normalization():
      net = discriminator_fn(net)

  Args:
    name_filter: Optionally, a method that takes a Variable name as input and
      returns whether this Variable should be normalized.
    power_iteration_rounds: The number of iterations of the power method to
      perform per step. A higher number yields a better approximation of the
      true spectral norm.

  Yields:
    A context manager that wraps the standard Keras variable creation method
    with the `spectral_normalization_custom_getter`.
  """
  original_make_variable = keras_base_layer_utils.make_variable
  sn_getter = spectral_normalization_custom_getter(
      name_filter=name_filter, power_iteration_rounds=power_iteration_rounds)

  def make_variable_wrapper(name, *args, **kwargs):
    return sn_getter(original_make_variable, name, *args, **kwargs)

  keras_base_layer_utils.make_variable = make_variable_wrapper

  yield

  keras_base_layer_utils.make_variable = original_make_variable
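
For intuition, the estimate that `compute_spectral_norm` builds in the graph can be written in a few lines of plain NumPy. This is a simplified sketch (the function name is hypothetical, and a fresh random `u` stands in for the persisted `spectral_norm_u` variable):

import numpy as np

def np_spectral_norm(w_tensor, power_iteration_rounds=1):
  # Flatten to 2-D with output channels last, mirroring the TF code above.
  w = w_tensor.reshape(-1, w_tensor.shape[-1])
  u = np.random.randn(w.shape[0], 1)
  for _ in range(power_iteration_rounds):
    v = w.T.dot(u)   # approximates the first right singular vector
    v /= np.linalg.norm(v)
    u = w.dot(v)     # approximates the first left singular vector
    u /= np.linalg.norm(u)
  return float(u.T.dot(w).dot(v))  # sigma is approximately u^T W v

w = np.random.randn(3, 3, 16, 32)
print(np_spectral_norm(w, power_iteration_rounds=50))         # estimate
print(np.linalg.svd(w.reshape(-1, 32), compute_uv=False)[0])  # exact value
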
tensorflow/contrib/gan/python/features/python/spectral_normalization_test.py
@@ -0,0 +1,354 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for features.spectral_normalization."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib import slim
from tensorflow.contrib.gan.python.features.python import spectral_normalization_impl as spectral_normalization
from tensorflow.contrib.layers.python.layers import layers as contrib_layers
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras.layers import convolutional as keras_convolutional
from tensorflow.python.keras.layers import core as keras_core
from tensorflow.python.layers import convolutional as layers_convolutional
from tensorflow.python.layers import core as layers_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class SpectralNormalizationTest(test.TestCase):

  def testComputeSpectralNorm(self):
    weights = variable_scope.get_variable(
        'w', dtype=dtypes.float32, shape=[2, 3, 50, 100])
    weights = math_ops.multiply(weights, 10.0)
    s = linalg_ops.svd(
        array_ops.reshape(weights, [-1, weights.shape[-1]]), compute_uv=False)
    true_sn = s[..., 0]
    estimated_sn = spectral_normalization.compute_spectral_norm(weights)

    with self.cached_session() as sess:
      sess.run(variables.global_variables_initializer())
      np_true_sn = sess.run(true_sn)
      # Each sess.run() performs one round of power iteration against the
      # persisted `u` variable, so the estimate after iteration i uses i + 1
      # cumulative rounds.
      for i in range(50):
        est = sess.run(estimated_sn)
        if i < 1:
          np_est_1 = est
        if i < 5:
          np_est_5 = est
        if i < 10:
          np_est_10 = est
        np_est_50 = est

      # Check that the estimate improves with more iterations.
      self.assertAlmostEqual(np_true_sn, np_est_50, 0)
      self.assertGreater(
          abs(np_true_sn - np_est_10), abs(np_true_sn - np_est_50))
      self.assertGreater(
          abs(np_true_sn - np_est_5), abs(np_true_sn - np_est_10))
      self.assertGreater(abs(np_true_sn - np_est_1), abs(np_true_sn - np_est_5))

  def testSpectralNormalize(self):
    weights = variable_scope.get_variable(
        'w', dtype=dtypes.float32, shape=[2, 3, 50, 100])
    weights = math_ops.multiply(weights, 10.0)
    normalized_weights = spectral_normalization.spectral_normalize(
        weights, power_iteration_rounds=1)

    unnormalized_sigma = linalg_ops.svd(
        array_ops.reshape(weights, [-1, weights.shape[-1]]),
        compute_uv=False)[..., 0]
    normalized_sigma = linalg_ops.svd(
        array_ops.reshape(normalized_weights, [-1, weights.shape[-1]]),
        compute_uv=False)[..., 0]

    with self.cached_session() as sess:
      sess.run(variables.global_variables_initializer())
      s0 = sess.run(unnormalized_sigma)

      for i in range(50):
        sigma = sess.run(normalized_sigma)
        if i < 1:
          s1 = sigma
        if i < 5:
          s5 = sigma
        if i < 10:
          s10 = sigma
        s50 = sigma

      self.assertAlmostEqual(1., s50, 0)
      self.assertGreater(abs(s10 - 1.), abs(s50 - 1.))
      self.assertGreater(abs(s5 - 1.), abs(s10 - 1.))
      self.assertGreater(abs(s1 - 1.), abs(s5 - 1.))
      self.assertGreater(abs(s0 - 1.), abs(s1 - 1.))

  def _testLayerHelper(self, build_layer_fn, w_shape, b_shape, is_keras=False):
    x = array_ops.placeholder(dtypes.float32, shape=[2, 10, 10, 3])

    w_initial = np.random.randn(*w_shape) * 10
    w_initializer = init_ops.constant_initializer(w_initial)
    b_initial = np.random.randn(*b_shape)
    b_initializer = init_ops.constant_initializer(b_initial)

    if is_keras:
      context_manager = spectral_normalization.keras_spectral_normalization()
    else:
      getter = spectral_normalization.spectral_normalization_custom_getter()
      context_manager = variable_scope.variable_scope('', custom_getter=getter)

    with context_manager:
      (net,
       expected_normalized_vars, expected_not_normalized_vars) = build_layer_fn(
           x, w_initializer, b_initializer)

    x_data = np.random.rand(*x.shape)

    with self.cached_session() as sess:
      sess.run(variables.global_variables_initializer())

      # Before running a forward pass we still expect the variable values to
      # differ from the initial value because of the normalizer.
      w_befores = []
      for name, var in expected_normalized_vars.items():
        w_before = sess.run(var)
        w_befores.append(w_before)
        self.assertFalse(
            np.allclose(w_initial, w_before),
            msg=('%s appears not to be normalized. Before: %s After: %s' %
                 (name, w_initial, w_before)))

      # Not true for the unnormalized variables.
      for name, var in expected_not_normalized_vars.items():
        b_before = sess.run(var)
        self.assertTrue(
            np.allclose(b_initial, b_before),
            msg=('%s appears to be unexpectedly normalized. '
                 'Before: %s After: %s' % (name, b_initial, b_before)))

      # Run a bunch of forward passes.
      for _ in range(1000):
        _ = sess.run(net, feed_dict={x: x_data})

      # We expect this to have improved the estimate of the spectral norm,
      # which should have changed the variable values and brought them close
      # to the true Spectral Normalized values.
      _, s, _ = np.linalg.svd(w_initial.reshape([-1, 3]))
      exactly_normalized = w_initial / s[0]
      for w_before, (name, var) in zip(w_befores,
                                       expected_normalized_vars.items()):
        w_after = sess.run(var)
        self.assertFalse(
            np.allclose(w_before, w_after, rtol=1e-8, atol=1e-8),
            msg=('%s did not improve over many iterations. '
                 'Before: %s After: %s' % (name, w_before, w_after)))
        self.assertAllClose(
            exactly_normalized,
            w_after,
            rtol=1e-4,
            atol=1e-4,
            msg=('Estimate of spectral norm for %s was inaccurate. '
                 'Normalized matrices do not match. '
                 'Estimate: %s Actual: %s' % (name, w_after,
                                              exactly_normalized)))

  def testConv2D_Layers(self):

    def build_layer_fn(x, w_initializer, b_initializer):
      layer = layers_convolutional.Conv2D(
          filters=3,
          kernel_size=3,
          padding='same',
          kernel_initializer=w_initializer,
          bias_initializer=b_initializer)
      net = layer.apply(x)
      expected_normalized_vars = {'tf.layers.Conv2d.kernel': layer.kernel}
      expected_not_normalized_vars = {'tf.layers.Conv2d.bias': layer.bias}

      return net, expected_normalized_vars, expected_not_normalized_vars

    self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,))

  def testConv2D_ContribLayers(self):

    def build_layer_fn(x, w_initializer, b_initializer):
      var_collection = {
          'weights': ['CONTRIB_LAYERS_CONV2D_WEIGHTS'],
          'biases': ['CONTRIB_LAYERS_CONV2D_BIASES']
      }
      net = contrib_layers.conv2d(
          x,
          3,
          3,
          weights_initializer=w_initializer,
          biases_initializer=b_initializer,
          variables_collections=var_collection)
      weight_vars = ops.get_collection('CONTRIB_LAYERS_CONV2D_WEIGHTS')
      self.assertEqual(1, len(weight_vars))
      bias_vars = ops.get_collection('CONTRIB_LAYERS_CONV2D_BIASES')
      self.assertEqual(1, len(bias_vars))
      expected_normalized_vars = {
          'contrib.layers.conv2d.weights': weight_vars[0]
      }
      expected_not_normalized_vars = {
          'contrib.layers.conv2d.bias': bias_vars[0]
      }

      return net, expected_normalized_vars, expected_not_normalized_vars

    self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,))

  def testConv2D_Slim(self):

    def build_layer_fn(x, w_initializer, b_initializer):
      var_collection = {
          'weights': ['SLIM_CONV2D_WEIGHTS'],
          'biases': ['SLIM_CONV2D_BIASES']
      }
      net = slim.conv2d(
          x,
          3,
          3,
          weights_initializer=w_initializer,
          biases_initializer=b_initializer,
          variables_collections=var_collection)
      weight_vars = ops.get_collection('SLIM_CONV2D_WEIGHTS')
      self.assertEqual(1, len(weight_vars))
      bias_vars = ops.get_collection('SLIM_CONV2D_BIASES')
      self.assertEqual(1, len(bias_vars))
      expected_normalized_vars = {'slim.conv2d.weights': weight_vars[0]}
      expected_not_normalized_vars = {'slim.conv2d.bias': bias_vars[0]}

      return net, expected_normalized_vars, expected_not_normalized_vars

    self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,))

  def testConv2D_Keras(self):

    def build_layer_fn(x, w_initializer, b_initializer):
      layer = keras_convolutional.Conv2D(
          filters=3,
          kernel_size=3,
          padding='same',
          kernel_initializer=w_initializer,
          bias_initializer=b_initializer)
      net = layer.apply(x)
      expected_normalized_vars = {'keras.layers.Conv2d.kernel': layer.kernel}
      expected_not_normalized_vars = {'keras.layers.Conv2d.bias': layer.bias}

      return net, expected_normalized_vars, expected_not_normalized_vars

    self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,), is_keras=True)

  def testFC_Layers(self):

    def build_layer_fn(x, w_initializer, b_initializer):
      x = layers_core.Flatten()(x)
      layer = layers_core.Dense(
          units=3,
          kernel_initializer=w_initializer,
          bias_initializer=b_initializer)
      net = layer.apply(x)
      expected_normalized_vars = {'tf.layers.Dense.kernel': layer.kernel}
      expected_not_normalized_vars = {'tf.layers.Dense.bias': layer.bias}

      return net, expected_normalized_vars, expected_not_normalized_vars

    self._testLayerHelper(build_layer_fn, (300, 3), (3,))

  def testFC_ContribLayers(self):

    def build_layer_fn(x, w_initializer, b_initializer):
      var_collection = {
          'weights': ['CONTRIB_LAYERS_FC_WEIGHTS'],
          'biases': ['CONTRIB_LAYERS_FC_BIASES']
      }
      x = contrib_layers.flatten(x)
      net = contrib_layers.fully_connected(
          x,
          3,
          weights_initializer=w_initializer,
          biases_initializer=b_initializer,
          variables_collections=var_collection)
      weight_vars = ops.get_collection('CONTRIB_LAYERS_FC_WEIGHTS')
      self.assertEqual(1, len(weight_vars))
      bias_vars = ops.get_collection('CONTRIB_LAYERS_FC_BIASES')
      self.assertEqual(1, len(bias_vars))
      expected_normalized_vars = {
          'contrib.layers.fully_connected.weights': weight_vars[0]
      }
      expected_not_normalized_vars = {
          'contrib.layers.fully_connected.bias': bias_vars[0]
      }

      return net, expected_normalized_vars, expected_not_normalized_vars

    self._testLayerHelper(build_layer_fn, (300, 3), (3,))

  def testFC_Slim(self):

    def build_layer_fn(x, w_initializer, b_initializer):
      var_collection = {
          'weights': ['SLIM_FC_WEIGHTS'],
          'biases': ['SLIM_FC_BIASES']
      }
      x = slim.flatten(x)
      net = slim.fully_connected(
          x,
          3,
          weights_initializer=w_initializer,
          biases_initializer=b_initializer,
          variables_collections=var_collection)
      weight_vars = ops.get_collection('SLIM_FC_WEIGHTS')
      self.assertEqual(1, len(weight_vars))
      bias_vars = ops.get_collection('SLIM_FC_BIASES')
      self.assertEqual(1, len(bias_vars))
      expected_normalized_vars = {
          'slim.fully_connected.weights': weight_vars[0]
      }
      expected_not_normalized_vars = {'slim.fully_connected.bias': bias_vars[0]}

      return net, expected_normalized_vars, expected_not_normalized_vars

    self._testLayerHelper(build_layer_fn, (300, 3), (3,))

  def testFC_Keras(self):

    def build_layer_fn(x, w_initializer, b_initializer):
      x = keras_core.Flatten()(x)
      layer = keras_core.Dense(
          units=3,
          kernel_initializer=w_initializer,
          bias_initializer=b_initializer)
      net = layer.apply(x)
      expected_normalized_vars = {'keras.layers.Dense.kernel': layer.kernel}
      expected_not_normalized_vars = {'keras.layers.Dense.bias': layer.bias}

      return net, expected_normalized_vars, expected_not_normalized_vars

    self._testLayerHelper(build_layer_fn, (300, 3), (3,), is_keras=True)


if __name__ == '__main__':
  test.main()
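
Finally, for Keras models, which bypass the variable-scope custom getter as the tests above exercise, a usage sketch of the context manager (the model layout is illustrative, assuming a TF 1.x build containing this change):

import tensorflow as tf
from tensorflow.contrib.gan.python.features.python import spectral_normalization

inputs = tf.keras.layers.Input(shape=(28, 28, 1))
# Inside the context, Keras variable creation is wrapped by the spectral
# normalization getter, so Conv2D/Dense kernels come back normalized while
# biases are left alone.
with spectral_normalization.keras_spectral_normalization():
  net = tf.keras.layers.Conv2D(64, 4, strides=2)(inputs)
  logits = tf.keras.layers.Dense(1)(tf.keras.layers.Flatten()(net))
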