Migrate TFGAN features to third_party.

PiperOrigin-RevId: 168060880
This commit is contained in:
A. Unique TensorFlower 2017-09-08 16:13:31 -07:00 committed by TensorFlower Gardener
parent d2ae1311f7
commit 48deb206ba
10 changed files with 1058 additions and 0 deletions

View File

@ -362,6 +362,7 @@ add_python_module("tensorflow/contrib/framework/python/framework")
add_python_module("tensorflow/contrib/framework/python/ops")
add_python_module("tensorflow/contrib/gan")
add_python_module("tensorflow/contrib/gan/python")
add_python_module("tensorflow/contrib/gan/python/features")
add_python_module("tensorflow/contrib/gan/python/losses")
add_python_module("tensorflow/contrib/graph_editor")
add_python_module("tensorflow/contrib/graph_editor/examples")

View File

@ -14,6 +14,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
":features",
":losses",
],
)
@ -29,6 +30,18 @@ py_library(
],
)
py_library(
name = "features",
srcs = ["python/features/__init__.py"],
srcs_version = "PY2AND3",
deps = [
":clip_weights",
":conditioning_utils",
":virtual_batchnorm",
"//tensorflow/python:util",
],
)
py_library(
name = "losses_impl",
srcs = ["python/losses/losses_impl.py"],
@ -96,6 +109,90 @@ py_test(
],
)
py_library(
name = "conditioning_utils",
srcs = ["python/features/conditioning_utils.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:embedding_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:tensor_util",
"//tensorflow/python:variable_scope",
],
)
py_test(
name = "conditioning_utils_test",
srcs = ["python/features/conditioning_utils_test.py"],
srcs_version = "PY2AND3",
deps = [
":conditioning_utils",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
],
)
py_library(
name = "virtual_batchnorm",
srcs = ["python/features/virtual_batchnorm.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
"//tensorflow/python:variable_scope",
],
)
py_test(
name = "virtual_batchnorm_test",
srcs = ["python/features/virtual_batchnorm_test.py"],
srcs_version = "PY2AND3",
deps = [
":virtual_batchnorm",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:layers",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
"//tensorflow/python:random_ops",
"//tensorflow/python:random_seed",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//third_party/py/numpy",
],
)
py_library(
name = "clip_weights",
srcs = ["python/features/clip_weights.py"],
srcs_version = "PY2AND3",
deps = ["//tensorflow/contrib/opt:opt_py"],
)
py_test(
name = "clip_weights_test",
srcs = ["python/features/clip_weights_test.py"],
srcs_version = "PY2AND3",
deps = [
":clip_weights",
"//tensorflow/python:client_testlib",
"//tensorflow/python:training",
"//tensorflow/python:variables",
],
)
filegroup(
name = "all_files",
srcs = glob(

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
# Collapse TFGAN into a tiered namespace.
from tensorflow.contrib.gan.python import features
from tensorflow.contrib.gan.python import losses
del absolute_import

View File

@ -0,0 +1,37 @@
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFGAN grouped API. Please see README.md for details and usage."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Collapse features into a single namespace.
# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.gan.python.features import clip_weights
from tensorflow.contrib.gan.python.features import conditioning_utils
from tensorflow.contrib.gan.python.features import virtual_batchnorm
from tensorflow.contrib.gan.python.features.clip_weights import *
from tensorflow.contrib.gan.python.features.conditioning_utils import *
from tensorflow.contrib.gan.python.features.virtual_batchnorm import *
# pylint: enable=unused-import,wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = clip_weights.__all__
_allowed_symbols += conditioning_utils.__all__
_allowed_symbols += virtual_batchnorm.__all__
remove_undocumented(__name__, _allowed_symbols)

View File

@ -0,0 +1,80 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to clip weights.
This is useful in the original formulation of the Wasserstein loss, which
requires that the discriminator be K-Lipschitz. See
https://arxiv.org/pdf/1701.07875 for more details.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.opt.python.training import variable_clipping_optimizer
__all__ = [
'clip_variables',
'clip_discriminator_weights',
]
def clip_discriminator_weights(optimizer, model, weight_clip):
"""Modifies an optimizer so it clips weights to a certain value.
Args:
optimizer: An optimizer to perform variable weight clipping.
model: A GANModel namedtuple.
weight_clip: Positive python float to clip discriminator weights. Used to
enforce a K-lipschitz condition, which is useful for some GAN training
schemes (ex WGAN: https://arxiv.org/pdf/1701.07875).
Returns:
An optimizer to perform weight clipping after updates.
Raises:
ValueError: If `weight_clip` is less than 0.
"""
return clip_variables(optimizer, model.discriminator_variables, weight_clip)
def clip_variables(optimizer, variables, weight_clip):
"""Modifies an optimizer so it clips weights to a certain value.
Args:
optimizer: An optimizer to perform variable weight clipping.
variables: A list of TensorFlow variables.
weight_clip: Positive python float to clip discriminator weights. Used to
enforce a K-lipschitz condition, which is useful for some GAN training
schemes (ex WGAN: https://arxiv.org/pdf/1701.07875).
Returns:
An optimizer to perform weight clipping after updates.
Raises:
ValueError: If `weight_clip` is less than 0.
"""
if weight_clip < 0:
raise ValueError(
'`discriminator_weight_clip` must be positive. Instead, was %s',
weight_clip)
return variable_clipping_optimizer.VariableClippingOptimizer(
opt=optimizer,
# Do no reduction, so clipping happens per-value.
vars_to_clip_dims={var: [] for var in variables},
max_norm=weight_clip,
use_locking=True,
colocate_clip_ops_with_vars=True)

View File

@ -0,0 +1,81 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfgan.python.features.clip_weights."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.contrib.gan.python.features import clip_weights
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import training
class ClipWeightsTest(test.TestCase):
"""Tests for `discriminator_weight_clip`."""
def setUp(self):
self.variables = [variables.Variable(2.0)]
self.tuple = collections.namedtuple(
'VarTuple', ['discriminator_variables'])(self.variables)
def _test_weight_clipping_helper(self, use_tuple):
loss = self.variables[0] * 2.0
opt = training.GradientDescentOptimizer(1.0)
if use_tuple:
opt_clip = clip_weights.weight_clip(opt, self.variables, 0.1)
else:
opt_clip = clip_weights.discriminator_weight_clip(opt, self.tuple, 0.1)
train_op1 = opt.minimize(loss, var_list=self.variables)
train_op2 = opt_clip.minimize(loss, var_list=self.variables)
with self.test_session(use_gpu=True) as sess:
sess.run(variables.global_variables_initializer())
self.assertEqual(2.0, self.variables[0].eval())
sess.run(train_op1)
self.assertLess(0.1, self.variables[0].eval())
with self.test_session(use_gpu=True) as sess:
sess.run(variables.global_variables_initializer())
self.assertEqual(2.0, self.variables[0].eval())
sess.run(train_op2)
self.assertNear(0.1, self.variables[0].eval(), 1e-7)
def test_weight_clipping_argsonly(self):
self._test_weight_clipping_helper(False)
def test_weight_clipping_ganmodel(self):
self._test_weight_clipping_helper(True)
def _test_incorrect_weight_clip_value_helper(self, use_tuple):
opt = training.GradientDescentOptimizer(1.0)
if use_tuple:
with self.assertRaisesRegexp(ValueError, 'must be positive'):
clip_weights.clip_discriminator_weights(opt, self.tuple, weight_clip=-1)
else:
with self.assertRaisesRegexp(ValueError, 'must be positive'):
clip_weights.clip_weights(opt, self.variables, weight_clip=-1)
def test_incorrect_weight_clip_value_argsonly(self):
self._test_incorrect_weight_clip_value_helper(False)
def test_incorrect_weight_clip_value_tuple(self):
self._test_incorrect_weight_clip_value_helper(True)

View File

@ -0,0 +1,112 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Miscellanous utilities for TFGAN code and examples.
Includes:
1) Conditioning the value of a Tensor, based on techniques from
https://arxiv.org/abs/1609.03499.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
__all__ = [
'condition_tensor',
'condition_tensor_from_onehot',
]
def _get_shape(tensor):
tensor_shape = array_ops.shape(tensor)
static_tensor_shape = tensor_util.constant_value(tensor_shape)
return (static_tensor_shape if static_tensor_shape is not None else
tensor_shape)
def condition_tensor(tensor, conditioning):
"""Condition the value of a tensor.
Conditioning scheme based on https://arxiv.org/abs/1609.03499.
Args:
tensor: A minibatch tensor to be conditioned.
conditioning: A minibatch Tensor of to condition on. Must be 2D, with first
dimension the same as `tensor`.
Returns:
`tensor` conditioned on `conditioning`.
Raises:
ValueError: If the non-batch dimensions of `tensor` aren't fully defined.
ValueError: If `conditioning` isn't at least 2D.
ValueError: If the batch dimension for the input Tensors don't match.
"""
tensor.shape[1:].assert_is_fully_defined()
num_features = tensor.shape[1:].num_elements()
mapped_conditioning = layers.linear(
layers.flatten(conditioning), num_features)
if not mapped_conditioning.shape.is_compatible_with(tensor.shape):
mapped_conditioning = array_ops.reshape(
mapped_conditioning, _get_shape(tensor))
return tensor + mapped_conditioning
def _one_hot_to_embedding(one_hot, embedding_size):
"""Get a dense embedding vector from a one-hot encoding."""
num_tokens = one_hot.shape[1]
label_id = math_ops.argmax(one_hot, axis=1)
embedding = variable_scope.get_variable(
'embedding', [num_tokens, embedding_size])
return embedding_ops.embedding_lookup(
embedding, label_id, name='token_to_embedding')
def _validate_onehot(one_hot_labels):
one_hot_labels.shape.assert_has_rank(2)
one_hot_labels.shape[1:].assert_is_fully_defined()
def condition_tensor_from_onehot(tensor, one_hot_labels, embedding_size=256):
"""Condition a tensor based on a one-hot tensor.
Conditioning scheme based on https://arxiv.org/abs/1609.03499.
Args:
tensor: Tensor to be conditioned.
one_hot_labels: A Tensor of one-hot labels. Shape is
[batch_size, num_classes].
embedding_size: The size of the class embedding.
Returns:
`tensor` conditioned on `one_hot_labels`.
Raises:
ValueError: `one_hot_labels` isn't 2D, if non-batch dimensions aren't
fully defined, or if batch sizes don't match.
"""
_validate_onehot(one_hot_labels)
conditioning = _one_hot_to_embedding(one_hot_labels, embedding_size)
return condition_tensor(tensor, conditioning)

View File

@ -0,0 +1,76 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfgan.python.features.conditioning_utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.gan.python.features import conditioning_utils
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class ConditioningUtilsTest(test.TestCase):
def test_condition_tensor_multiple_shapes(self):
for tensor_shape in [(4, 1), (4, 2), (4, 2, 6), (None, 5, 3)]:
for conditioning_shape in [(4, 1), (4, 8), (4, 5, 3)]:
conditioning_utils.condition_tensor(
array_ops.placeholder(dtypes.float32, tensor_shape),
array_ops.placeholder(dtypes.float32, conditioning_shape))
def test_condition_tensor_asserts(self):
with self.assertRaisesRegexp(ValueError, 'Cannot reshape'):
conditioning_utils.condition_tensor(
array_ops.placeholder(dtypes.float32, (4, 1)),
array_ops.placeholder(dtypes.float32, (5, 1)))
with self.assertRaisesRegexp(ValueError, 'Shape .* is not fully defined'):
conditioning_utils.condition_tensor(
array_ops.placeholder(dtypes.float32, (5, None)),
array_ops.placeholder(dtypes.float32, (5, 1)))
with self.assertRaisesRegexp(ValueError, 'must have a least 2 dimensions.'):
conditioning_utils.condition_tensor(
array_ops.placeholder(dtypes.float32, (5, 2)),
array_ops.placeholder(dtypes.float32, (5)))
def test_condition_tensor_from_onehot(self):
conditioning_utils.condition_tensor_from_onehot(
array_ops.placeholder(dtypes.float32, (5, 4, 1)),
array_ops.placeholder(dtypes.float32, (5, 10)))
def test_condition_tensor_from_onehot_asserts(self):
with self.assertRaisesRegexp(ValueError, 'Shape .* must have rank 2'):
conditioning_utils.condition_tensor_from_onehot(
array_ops.placeholder(dtypes.float32, (5, 1)),
array_ops.placeholder(dtypes.float32, (5)))
with self.assertRaisesRegexp(ValueError, 'Shape .* is not fully defined'):
conditioning_utils.condition_tensor_from_onehot(
array_ops.placeholder(dtypes.float32, (5, 1)),
array_ops.placeholder(dtypes.float32, (5, None)))
with self.assertRaisesRegexp(ValueError, 'Cannot reshape a tensor'):
conditioning_utils.condition_tensor_from_onehot(
array_ops.placeholder(dtypes.float32, (5, 1)),
array_ops.placeholder(dtypes.float32, (4, 6)))
if __name__ == '__main__':
test.main()

View File

@ -0,0 +1,306 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Virtual batch normalization.
This technique was first introduced in `Improved Techniques for Training GANs`
(Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using batch
normalization on a minibatch, it fixes a reference subset of the data to use for
calculating normalization statistics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope
__all__ = [
'VBN',
]
def _static_or_dynamic_batch_size(tensor, batch_axis):
"""Returns the static or dynamic batch size."""
batch_size = array_ops.shape(tensor)[batch_axis]
static_batch_size = tensor_util.constant_value(batch_size)
return static_batch_size or batch_size
def _statistics(x, axes):
"""Calculate the mean and mean square of `x`.
Modified from the implementation of `tf.nn.moments`.
Args:
x: A `Tensor`.
axes: Array of ints. Axes along which to compute mean and
variance.
Returns:
Two `Tensor` objects: `mean` and `square mean`.
"""
# The dynamic range of fp16 is too limited to support the collection of
# sufficient statistics. As a workaround we simply perform the operations
# on 32-bit floats before converting the mean and variance back to fp16
y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
# Compute true mean while keeping the dims for proper broadcasting.
shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keep_dims=True))
shifted_mean = math_ops.reduce_mean(y - shift, axes, keep_dims=True)
mean = shifted_mean + shift
mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keep_dims=True)
mean = array_ops.squeeze(mean, axes)
mean_squared = array_ops.squeeze(mean_squared, axes)
if x.dtype == dtypes.float16:
return (math_ops.cast(mean, dtypes.float16),
math_ops.cast(mean_squared, dtypes.float16))
else:
return (mean, mean_squared)
def _validate_init_input_and_get_axis(reference_batch, axis):
"""Validate input and return the used axis value."""
if reference_batch.shape.ndims is None:
raise ValueError('`reference_batch` has unknown dimensions.')
ndims = reference_batch.shape.ndims
if axis < 0:
used_axis = ndims + axis
else:
used_axis = axis
if used_axis < 0 or used_axis >= ndims:
raise ValueError('Value of `axis` argument ' + str(used_axis) +
' is out of range for input with rank ' + str(ndims))
return used_axis
def _validate_call_input(tensor_list, batch_dim):
"""Verifies that tensor shapes are compatible, except for `batch_dim`."""
def _get_shape(tensor):
shape = tensor.shape.as_list()
del shape[batch_dim]
return shape
base_shape = tensor_shape.TensorShape(_get_shape(tensor_list[0]))
for tensor in tensor_list:
base_shape.assert_is_compatible_with(_get_shape(tensor))
class VBN(object):
"""A class to perform virtual batch normalization.
This technique was first introduced in `Improved Techniques for Training GANs`
(Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using batch
normalization on a minibatch, it fixes a reference subset of the data to use
for calculating normalization statistics.
To do this, we calculate the reference batch mean and mean square, and modify
those statistics for each example. We use mean square instead of variance,
since it is linear.
Note that if `center` or `scale` variables are created, they are shared
between all calls to this object.
The `__init__` API is intended to mimic `tf.layers.batch_normalization` as
closely as possible.
"""
def __init__(self,
reference_batch,
axis=-1,
epsilon=1e-3,
center=True,
scale=True,
beta_initializer=init_ops.zeros_initializer(),
gamma_initializer=init_ops.ones_initializer(),
beta_regularizer=None,
gamma_regularizer=None,
trainable=True,
name=None,
batch_axis=0):
"""Initialize virtual batch normalization object.
We precompute the 'mean' and 'mean squared' of the reference batch, so that
`__call__` is efficient. This means that the axis must be supplied when the
object is created, not when it is called.
We precompute 'square mean' instead of 'variance', because the square mean
can be easily adjusted on a per-example basis.
Args:
reference_batch: A minibatch tensors. This will form the reference data
from which the normalization statistics are calculated. See
https://arxiv.org/abs/1606.03498 for more details.
axis: Integer, the axis that should be normalized (typically the features
axis). For instance, after a `Convolution2D` layer with
`data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
epsilon: Small float added to variance to avoid dividing by zero.
center: If True, add offset of `beta` to normalized tensor. If False,
`beta` is ignored.
scale: If True, multiply by `gamma`. If False, `gamma` is
not used. When the next layer is linear (also e.g. `nn.relu`), this can
be disabled since the scaling can be done by the next layer.
beta_initializer: Initializer for the beta weight.
gamma_initializer: Initializer for the gamma weight.
beta_regularizer: Optional regularizer for the beta weight.
gamma_regularizer: Optional regularizer for the gamma weight.
trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
name: String, the name of the ops.
batch_axis: The axis of the batch dimension. This dimension is treated
differently in `virtual batch normalization` vs `batch normalization`.
Raises:
ValueError: If `reference_batch` has unknown dimensions at graph
construction.
ValueError: If `batch_axis` is the same as `axis`.
"""
axis = _validate_init_input_and_get_axis(reference_batch, axis)
self._epsilon = epsilon
self._beta = 0
self._gamma = 1
self._batch_axis = _validate_init_input_and_get_axis(
reference_batch, batch_axis)
if axis == self._batch_axis:
raise ValueError('`axis` and `batch_axis` cannot be the same.')
with variable_scope.variable_scope(name, 'VBN',
values=[reference_batch]) as self._vs:
self._reference_batch = reference_batch
# Calculate important shapes:
# 1) Reduction axes for the reference batch
# 2) Broadcast shape, if necessary
# 3) Reduction axes for the virtual batchnormed batch
# 4) Shape for optional parameters
input_shape = self._reference_batch.shape
ndims = input_shape.ndims
reduction_axes = list(range(ndims))
del reduction_axes[axis]
self._broadcast_shape = [1] * len(input_shape)
self._broadcast_shape[axis] = input_shape[axis].value
self._example_reduction_axes = list(range(ndims))
del self._example_reduction_axes[max(axis, self._batch_axis)]
del self._example_reduction_axes[min(axis, self._batch_axis)]
params_shape = self._reference_batch.shape[axis]
# Determines whether broadcasting is needed. This is slightly different
# than in the `nn.batch_normalization` case, due to `batch_dim`.
self._needs_broadcasting = (
sorted(self._example_reduction_axes) != list(range(ndims))[:-2])
# Calculate the sufficient statistics for the reference batch in a way
# that can be easily modified by additional examples.
self._ref_mean, self._ref_mean_squares = _statistics(
self._reference_batch, reduction_axes)
self._ref_variance = (self._ref_mean_squares -
math_ops.square(self._ref_mean))
# Virtual batch normalization uses a weighted average between example
# statistics and the reference batch statistics.
ref_batch_size = _static_or_dynamic_batch_size(
self._reference_batch, self._batch_axis)
self._example_weight = 1. / (math_ops.to_float(ref_batch_size) + 1.)
self._ref_weight = 1. - self._example_weight
# Make the variables, if necessary.
if center:
self._beta = variable_scope.get_variable(
name='beta',
shape=(params_shape,),
initializer=beta_initializer,
regularizer=beta_regularizer,
trainable=trainable)
if scale:
self._gamma = variable_scope.get_variable(
name='gamma',
shape=(params_shape,),
initializer=gamma_initializer,
regularizer=gamma_regularizer,
trainable=trainable)
def _virtual_statistics(self, inputs, reduction_axes):
"""Compute the statistics needed for virtual batch normalization."""
cur_mean, cur_mean_sq = _statistics(inputs, reduction_axes)
vb_mean = (self._example_weight * cur_mean +
self._ref_weight * self._ref_mean)
vb_mean_sq = (self._example_weight * cur_mean_sq +
self._ref_weight * self._ref_mean_squares)
return (vb_mean, vb_mean_sq)
def _broadcast(self, v, broadcast_shape=None):
# The exact broadcast shape depends on the current batch, not the reference
# batch, unless we're calculating the batch normalization of the reference
# batch.
b_shape = broadcast_shape or self._broadcast_shape
if self._needs_broadcasting and v is not None:
return array_ops.reshape(v, b_shape)
return v
def reference_batch_normalization(self):
"""Return the reference batch, but batch normalized."""
with ops.name_scope(self._vs.name):
return nn.batch_normalization(self._reference_batch,
self._broadcast(self._ref_mean),
self._broadcast(self._ref_variance),
self._broadcast(self._beta),
self._broadcast(self._gamma),
self._epsilon)
def __call__(self, inputs):
"""Run virtual batch normalization on inputs.
Args:
inputs: Tensor input.
Returns:
A virtual batch normalized version of `inputs`.
Raises:
ValueError: If `inputs` shape isn't compatible with the reference batch.
"""
_validate_call_input([inputs, self._reference_batch], self._batch_axis)
with ops.name_scope(self._vs.name, values=[inputs, self._reference_batch]):
# Calculate the statistics on the current input on a per-example basis.
vb_mean, vb_mean_sq = self._virtual_statistics(
inputs, self._example_reduction_axes)
vb_variance = vb_mean_sq - math_ops.square(vb_mean)
# The exact broadcast shape of the input statistic Tensors depends on the
# current batch, not the reference batch. The parameter broadcast shape
# is independent of the shape of the input statistic Tensor dimensions.
b_shape = self._broadcast_shape[:] # deep copy
b_shape[self._batch_axis] = _static_or_dynamic_batch_size(
inputs, self._batch_axis)
return nn.batch_normalization(
inputs,
self._broadcast(vb_mean, b_shape),
self._broadcast(vb_variance, b_shape),
self._broadcast(self._beta, self._broadcast_shape),
self._broadcast(self._gamma, self._broadcast_shape),
self._epsilon)

View File

@ -0,0 +1,267 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfgan.python.features.virtual_batchnorm."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.framework.python.ops import variables as contrib_variables_lib
from tensorflow.contrib.gan.python.features import virtual_batchnorm
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.layers import normalization
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
class VirtualBatchnormTest(test.TestCase):
def test_syntax(self):
reference_batch = array_ops.zeros([5, 3, 16, 9, 15])
vbn = virtual_batchnorm.VBN(reference_batch, batch_axis=1)
vbn(array_ops.ones([5, 7, 16, 9, 15]))
def test_no_broadcast_needed(self):
"""When `axis` and `batch_axis` are at the end, no broadcast is needed."""
reference_batch = array_ops.zeros([5, 3, 16, 9, 15])
minibatch = array_ops.zeros([5, 3, 16, 3, 15])
vbn = virtual_batchnorm.VBN(reference_batch, axis=-1, batch_axis=-2)
vbn(minibatch)
def test_statistics(self):
"""Check that `_statistics` gives the same result as `nn.moments`."""
random_seed.set_random_seed(1234)
tensors = random_ops.random_normal([4, 5, 7, 3])
for axes in [(3), (0, 2), (1, 2, 3)]:
vb_mean, mean_sq = virtual_batchnorm._statistics(tensors, axes)
mom_mean, mom_var = nn.moments(tensors, axes)
vb_var = mean_sq - math_ops.square(vb_mean)
with self.test_session(use_gpu=True) as sess:
vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run([
vb_mean, vb_var, mom_mean, mom_var])
self.assertAllClose(mom_mean_np, vb_mean_np)
self.assertAllClose(mom_var_np, vb_var_np)
def test_virtual_statistics(self):
"""Check that `_virtual_statistics` gives same result as `nn.moments`."""
random_seed.set_random_seed(1234)
batch_axis = 0
partial_batch = random_ops.random_normal([4, 5, 7, 3])
single_example = random_ops.random_normal([1, 5, 7, 3])
full_batch = array_ops.concat([partial_batch, single_example], axis=0)
for reduction_axis in range(1, 4):
# Get `nn.moments` on the full batch.
reduction_axes = list(range(4))
del reduction_axes[reduction_axis]
mom_mean, mom_variance = nn.moments(full_batch, reduction_axes)
# Get virtual batch statistics.
vb_reduction_axes = list(range(4))
del vb_reduction_axes[reduction_axis]
del vb_reduction_axes[batch_axis]
vbn = virtual_batchnorm.VBN(partial_batch, reduction_axis)
vb_mean, mean_sq = vbn._virtual_statistics(
single_example, vb_reduction_axes)
vb_variance = mean_sq - math_ops.square(vb_mean)
# Remove singleton batch dim for easy comparisons.
vb_mean = array_ops.squeeze(vb_mean, batch_axis)
vb_variance = array_ops.squeeze(vb_variance, batch_axis)
with self.test_session(use_gpu=True) as sess:
vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run([
vb_mean, vb_variance, mom_mean, mom_variance])
self.assertAllClose(mom_mean_np, vb_mean_np)
self.assertAllClose(mom_var_np, vb_var_np)
def test_reference_batch_normalization(self):
"""Check that batch norm from VBN agrees with opensource implementation."""
random_seed.set_random_seed(1234)
batch = random_ops.random_normal([6, 5, 7, 3, 3])
for axis in range(5):
# Get `layers` batchnorm result.
bn_normalized = normalization.batch_normalization(
batch, axis, training=True)
# Get VBN's batch normalization on reference batch.
batch_axis = 0 if axis is not 0 else 1 # axis and batch_axis can't same
vbn = virtual_batchnorm.VBN(batch, axis, batch_axis=batch_axis)
vbn_normalized = vbn.reference_batch_normalization()
with self.test_session(use_gpu=True) as sess:
variables_lib.global_variables_initializer().run()
bn_normalized_np, vbn_normalized_np = sess.run(
[bn_normalized, vbn_normalized])
self.assertAllClose(bn_normalized_np, vbn_normalized_np)
def test_same_as_batchnorm(self):
"""Check that batch norm on set X is the same as ref of X / y on `y`."""
random_seed.set_random_seed(1234)
num_examples = 4
examples = [random_ops.random_normal([5, 7, 3]) for _ in
range(num_examples)]
# Get the result of the opensource batch normalization.
batch_normalized = normalization.batch_normalization(
array_ops.stack(examples), training=True)
for i in range(num_examples):
examples_except_i = array_ops.stack(examples[:i] + examples[i+1:])
# Get the result of VBN's batch normalization.
vbn = virtual_batchnorm.VBN(examples_except_i)
vb_normed = array_ops.squeeze(
vbn(array_ops.expand_dims(examples[i], [0])), [0])
with self.test_session(use_gpu=True) as sess:
variables_lib.global_variables_initializer().run()
bn_np, vb_np = sess.run([batch_normalized, vb_normed])
self.assertAllClose(bn_np[i, ...], vb_np)
def test_minibatch_independent(self):
"""Test that virtual batch normalized exampels are independent.
Unlike batch normalization, virtual batch normalization has the property
that the virtual batch normalized value of an example is independent of the
other examples in the minibatch. In this test, we verify this property.
"""
random_seed.set_random_seed(1234)
# These can be random, but must be the same for all session calls.
reference_batch = constant_op.constant(
np.random.normal(size=[4, 7, 3]), dtype=dtypes.float32)
fixed_example = constant_op.constant(np.random.normal(size=[7, 3]),
dtype=dtypes.float32)
# Get the VBN object and the virtual batch normalized value for
# `fixed_example`.
vbn = virtual_batchnorm.VBN(reference_batch)
vbn_fixed_example = array_ops.squeeze(
vbn(array_ops.expand_dims(fixed_example, 0)), 0)
with self.test_session(use_gpu=True):
variables_lib.global_variables_initializer().run()
vbn_fixed_example_np = vbn_fixed_example.eval()
# Check that the value is the same for different minibatches, and different
# sized minibatches.
for minibatch_size in range(1, 6):
examples = [random_ops.random_normal([7, 3]) for _ in
range(minibatch_size)]
minibatch = array_ops.stack([fixed_example] + examples)
vbn_minibatch = vbn(minibatch)
cur_vbn_fixed_example = vbn_minibatch[0, ...]
with self.test_session(use_gpu=True):
variables_lib.global_variables_initializer().run()
cur_vbn_fixed_example_np = cur_vbn_fixed_example.eval()
self.assertAllClose(vbn_fixed_example_np, cur_vbn_fixed_example_np)
def test_variable_reuse(self):
"""Test that variable scopes work and inference on a real-ish case."""
tensor1_ref = array_ops.zeros([6, 5, 7, 3, 3])
tensor1_examples = array_ops.zeros([4, 5, 7, 3, 3])
tensor2_ref = array_ops.zeros([4, 2, 3])
tensor2_examples = array_ops.zeros([2, 2, 3])
with variable_scope.variable_scope('dummy_scope', reuse=True):
with self.assertRaisesRegexp(
ValueError, 'does not exist, or was not created with '
'tf.get_variable()'):
virtual_batchnorm.VBN(tensor1_ref)
vbn1 = virtual_batchnorm.VBN(tensor1_ref, name='vbn1')
vbn2 = virtual_batchnorm.VBN(tensor2_ref, name='vbn2')
# Fetch reference and examples after virtual batch normalization. Also
# fetch in variable reuse case.
to_fetch = []
to_fetch.append(vbn1.reference_batch_normalization())
to_fetch.append(vbn2.reference_batch_normalization())
to_fetch.append(vbn1(tensor1_examples))
to_fetch.append(vbn2(tensor2_examples))
variable_scope.get_variable_scope().reuse_variables()
to_fetch.append(vbn1.reference_batch_normalization())
to_fetch.append(vbn2.reference_batch_normalization())
to_fetch.append(vbn1(tensor1_examples))
to_fetch.append(vbn2(tensor2_examples))
self.assertEqual(4, len(contrib_variables_lib.get_variables()))
with self.test_session(use_gpu=True) as sess:
variables_lib.global_variables_initializer().run()
sess.run(to_fetch)
def test_invalid_input(self):
# Reference batch has unknown dimensions.
with self.assertRaisesRegexp(
ValueError, '`reference_batch` has unknown dimensions.'):
virtual_batchnorm.VBN(array_ops.placeholder(dtypes.float32), name='vbn1')
# Axis too negative.
with self.assertRaisesRegexp(
ValueError, 'Value of `axis` argument .* is out of range'):
virtual_batchnorm.VBN(array_ops.zeros([1, 2]), axis=-3, name='vbn2')
# Axis too large.
with self.assertRaisesRegexp(
ValueError, 'Value of `axis` argument .* is out of range'):
virtual_batchnorm.VBN(array_ops.zeros([1, 2]), axis=2, name='vbn3')
# Batch axis too negative.
with self.assertRaisesRegexp(
ValueError, 'Value of `axis` argument .* is out of range'):
virtual_batchnorm.VBN(array_ops.zeros([1, 2]), name='vbn4', batch_axis=-3)
# Batch axis too large.
with self.assertRaisesRegexp(
ValueError, 'Value of `axis` argument .* is out of range'):
virtual_batchnorm.VBN(array_ops.zeros([1, 2]), name='vbn5', batch_axis=2)
# Axis and batch axis are the same.
with self.assertRaisesRegexp(
ValueError, '`axis` and `batch_axis` cannot be the same.'):
virtual_batchnorm.VBN(array_ops.zeros(
[1, 2]), axis=1, name='vbn6', batch_axis=1)
# Reference Tensor and example Tensor have incompatible shapes.
tensor_ref = array_ops.zeros([5, 2, 3])
tensor_examples = array_ops.zeros([3, 2, 3])
vbn = virtual_batchnorm.VBN(tensor_ref, name='vbn7', batch_axis=1)
with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'):
vbn(tensor_examples)
if __name__ == '__main__':
test.main()