Migrate TFGAN features to third_party.
PiperOrigin-RevId: 168060880
parent d2ae1311f7
commit 48deb206ba
@@ -362,6 +362,7 @@ add_python_module("tensorflow/contrib/framework/python/framework")
add_python_module("tensorflow/contrib/framework/python/ops")
add_python_module("tensorflow/contrib/gan")
add_python_module("tensorflow/contrib/gan/python")
add_python_module("tensorflow/contrib/gan/python/features")
add_python_module("tensorflow/contrib/gan/python/losses")
add_python_module("tensorflow/contrib/graph_editor")
add_python_module("tensorflow/contrib/graph_editor/examples")
@@ -14,6 +14,7 @@ py_library(
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":features",
        ":losses",
    ],
)
@@ -29,6 +30,18 @@ py_library(
    ],
)

py_library(
    name = "features",
    srcs = ["python/features/__init__.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":clip_weights",
        ":conditioning_utils",
        ":virtual_batchnorm",
        "//tensorflow/python:util",
    ],
)

py_library(
    name = "losses_impl",
    srcs = ["python/losses/losses_impl.py"],
@@ -96,6 +109,90 @@ py_test(
    ],
)

py_library(
    name = "conditioning_utils",
    srcs = ["python/features/conditioning_utils.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/contrib/layers:layers_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:embedding_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:variable_scope",
    ],
)

py_test(
    name = "conditioning_utils_test",
    srcs = ["python/features/conditioning_utils_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":conditioning_utils",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
    ],
)

py_library(
    name = "virtual_batchnorm",
    srcs = ["python/features/virtual_batchnorm.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:variable_scope",
    ],
)

py_test(
    name = "virtual_batchnorm_test",
    srcs = ["python/features/virtual_batchnorm_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":virtual_batchnorm",
        "//tensorflow/contrib/framework:framework_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:layers",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:random_seed",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//third_party/py/numpy",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function

# Collapse TFGAN into a tiered namespace.
from tensorflow.contrib.gan.python import features
from tensorflow.contrib.gan.python import losses

del absolute_import
tensorflow/contrib/gan/python/features/__init__.py (Normal file, 37 lines)
@@ -0,0 +1,37 @@
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFGAN grouped API. Please see README.md for details and usage."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Collapse features into a single namespace.
# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.gan.python.features import clip_weights
from tensorflow.contrib.gan.python.features import conditioning_utils
from tensorflow.contrib.gan.python.features import virtual_batchnorm

from tensorflow.contrib.gan.python.features.clip_weights import *
from tensorflow.contrib.gan.python.features.conditioning_utils import *
from tensorflow.contrib.gan.python.features.virtual_batchnorm import *
# pylint: enable=unused-import,wildcard-import

from tensorflow.python.util.all_util import remove_undocumented

_allowed_symbols = clip_weights.__all__
_allowed_symbols += conditioning_utils.__all__
_allowed_symbols += virtual_batchnorm.__all__
remove_undocumented(__name__, _allowed_symbols)
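Note: with the `__init__.py` above, the migrated utilities surface under a tiered namespace. A minimal usage sketch, assuming a TF 1.x build where contrib re-exports `gan` (the symbols below come from the modules' `__all__` lists; this is illustrative, not part of the commit):

import tensorflow as tf

tfgan = tf.contrib.gan
# Both the `features` submodule and its public symbols are re-exported.
clip_fn = tfgan.features.clip_variables
vbn = tfgan.features.VBN(tf.zeros([4, 7, 3]))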
tensorflow/contrib/gan/python/features/clip_weights.py (Normal file, 80 lines)
@@ -0,0 +1,80 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to clip weights.

This is useful in the original formulation of the Wasserstein loss, which
requires that the discriminator be K-Lipschitz. See
https://arxiv.org/pdf/1701.07875 for more details.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.opt.python.training import variable_clipping_optimizer


__all__ = [
    'clip_variables',
    'clip_discriminator_weights',
]


def clip_discriminator_weights(optimizer, model, weight_clip):
  """Modifies an optimizer so it clips weights to a certain value.

  Args:
    optimizer: An optimizer to perform variable weight clipping.
    model: A GANModel namedtuple.
    weight_clip: Positive python float to clip discriminator weights. Used to
      enforce a K-Lipschitz condition, which is useful for some GAN training
      schemes (e.g. WGAN: https://arxiv.org/pdf/1701.07875).

  Returns:
    An optimizer to perform weight clipping after updates.

  Raises:
    ValueError: If `weight_clip` is less than 0.
  """
  return clip_variables(optimizer, model.discriminator_variables, weight_clip)


def clip_variables(optimizer, variables, weight_clip):
  """Modifies an optimizer so it clips weights to a certain value.

  Args:
    optimizer: An optimizer to perform variable weight clipping.
    variables: A list of TensorFlow variables.
    weight_clip: Positive python float to clip discriminator weights. Used to
      enforce a K-Lipschitz condition, which is useful for some GAN training
      schemes (e.g. WGAN: https://arxiv.org/pdf/1701.07875).

  Returns:
    An optimizer to perform weight clipping after updates.

  Raises:
    ValueError: If `weight_clip` is less than 0.
  """
  if weight_clip < 0:
    raise ValueError(
        '`weight_clip` must be positive. Instead, was %s' % weight_clip)
  return variable_clipping_optimizer.VariableClippingOptimizer(
      opt=optimizer,
      # Do no reduction, so clipping happens per-value.
      vars_to_clip_dims={var: [] for var in variables},
      max_norm=weight_clip,
      use_locking=True,
      colocate_clip_ops_with_vars=True)
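For context, a sketch of how the wrapper above might be used for a WGAN-style critic update. The variable and loss here are illustrative placeholders, not part of this commit:

import tensorflow as tf
from tensorflow.contrib.gan.python.features import clip_weights

critic_var = tf.get_variable('critic_w', shape=[3, 3])  # hypothetical critic weight
loss = tf.reduce_sum(critic_var)                        # stand-in for a critic loss
opt = tf.train.GradientDescentOptimizer(1e-4)
# Every update is followed by a per-value clip of the critic weights to 0.01.
clip_opt = clip_weights.clip_variables(opt, [critic_var], weight_clip=0.01)
train_op = clip_opt.minimize(loss, var_list=[critic_var])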
tensorflow/contrib/gan/python/features/clip_weights_test.py (Normal file, 81 lines)
@@ -0,0 +1,81 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfgan.python.features.clip_weights."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.contrib.gan.python.features import clip_weights

from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import training


class ClipWeightsTest(test.TestCase):
  """Tests for `clip_discriminator_weights` and `clip_variables`."""

  def setUp(self):
    self.variables = [variables.Variable(2.0)]
    self.tuple = collections.namedtuple(
        'VarTuple', ['discriminator_variables'])(self.variables)

  def _test_weight_clipping_helper(self, use_tuple):
    loss = self.variables[0] * 2.0
    opt = training.GradientDescentOptimizer(1.0)
    if use_tuple:
      opt_clip = clip_weights.clip_discriminator_weights(opt, self.tuple, 0.1)
    else:
      opt_clip = clip_weights.clip_variables(opt, self.variables, 0.1)

    train_op1 = opt.minimize(loss, var_list=self.variables)
    train_op2 = opt_clip.minimize(loss, var_list=self.variables)

    with self.test_session(use_gpu=True) as sess:
      sess.run(variables.global_variables_initializer())
      self.assertEqual(2.0, self.variables[0].eval())
      sess.run(train_op1)
      self.assertLess(0.1, self.variables[0].eval())

    with self.test_session(use_gpu=True) as sess:
      sess.run(variables.global_variables_initializer())
      self.assertEqual(2.0, self.variables[0].eval())
      sess.run(train_op2)
      self.assertNear(0.1, self.variables[0].eval(), 1e-7)

  def test_weight_clipping_argsonly(self):
    self._test_weight_clipping_helper(False)

  def test_weight_clipping_ganmodel(self):
    self._test_weight_clipping_helper(True)

  def _test_incorrect_weight_clip_value_helper(self, use_tuple):
    opt = training.GradientDescentOptimizer(1.0)

    if use_tuple:
      with self.assertRaisesRegexp(ValueError, 'must be positive'):
        clip_weights.clip_discriminator_weights(opt, self.tuple,
                                                weight_clip=-1)
    else:
      with self.assertRaisesRegexp(ValueError, 'must be positive'):
        clip_weights.clip_variables(opt, self.variables, weight_clip=-1)

  def test_incorrect_weight_clip_value_argsonly(self):
    self._test_incorrect_weight_clip_value_helper(False)

  def test_incorrect_weight_clip_value_tuple(self):
    self._test_incorrect_weight_clip_value_helper(True)
tensorflow/contrib/gan/python/features/conditioning_utils.py (Normal file, 112 lines)
@@ -0,0 +1,112 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Miscellaneous utilities for TFGAN code and examples.

Includes:
1) Conditioning the value of a Tensor, based on techniques from
   https://arxiv.org/abs/1609.03499.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.layers.python.layers import layers
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope


__all__ = [
    'condition_tensor',
    'condition_tensor_from_onehot',
]


def _get_shape(tensor):
  tensor_shape = array_ops.shape(tensor)
  static_tensor_shape = tensor_util.constant_value(tensor_shape)
  return (static_tensor_shape if static_tensor_shape is not None else
          tensor_shape)


def condition_tensor(tensor, conditioning):
  """Condition the value of a tensor.

  Conditioning scheme based on https://arxiv.org/abs/1609.03499.

  Args:
    tensor: A minibatch tensor to be conditioned.
    conditioning: A minibatch Tensor to condition on. Must be 2D, with first
      dimension the same as `tensor`.

  Returns:
    `tensor` conditioned on `conditioning`.

  Raises:
    ValueError: If the non-batch dimensions of `tensor` aren't fully defined.
    ValueError: If `conditioning` isn't at least 2D.
    ValueError: If the batch dimensions of the input Tensors don't match.
  """
  tensor.shape[1:].assert_is_fully_defined()
  num_features = tensor.shape[1:].num_elements()

  mapped_conditioning = layers.linear(
      layers.flatten(conditioning), num_features)
  if not mapped_conditioning.shape.is_compatible_with(tensor.shape):
    mapped_conditioning = array_ops.reshape(
        mapped_conditioning, _get_shape(tensor))
  return tensor + mapped_conditioning


def _one_hot_to_embedding(one_hot, embedding_size):
  """Get a dense embedding vector from a one-hot encoding."""
  num_tokens = one_hot.shape[1]
  label_id = math_ops.argmax(one_hot, axis=1)
  embedding = variable_scope.get_variable(
      'embedding', [num_tokens, embedding_size])
  return embedding_ops.embedding_lookup(
      embedding, label_id, name='token_to_embedding')


def _validate_onehot(one_hot_labels):
  one_hot_labels.shape.assert_has_rank(2)
  one_hot_labels.shape[1:].assert_is_fully_defined()


def condition_tensor_from_onehot(tensor, one_hot_labels, embedding_size=256):
  """Condition a tensor based on a one-hot tensor.

  Conditioning scheme based on https://arxiv.org/abs/1609.03499.

  Args:
    tensor: Tensor to be conditioned.
    one_hot_labels: A Tensor of one-hot labels. Shape is
      [batch_size, num_classes].
    embedding_size: The size of the class embedding.

  Returns:
    `tensor` conditioned on `one_hot_labels`.

  Raises:
    ValueError: If `one_hot_labels` isn't 2D, if non-batch dimensions aren't
      fully defined, or if batch sizes don't match.
  """
  _validate_onehot(one_hot_labels)

  conditioning = _one_hot_to_embedding(one_hot_labels, embedding_size)
  return condition_tensor(tensor, conditioning)
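As a quick illustration, a sketch of conditioning intermediate generator activations on one-hot class labels with the helpers above (the shapes and placeholder names are illustrative assumptions, not from this commit):

import tensorflow as tf
from tensorflow.contrib.gan.python.features import conditioning_utils

net = tf.placeholder(tf.float32, [16, 4, 4, 64])  # intermediate activations
one_hot = tf.placeholder(tf.float32, [16, 10])    # one-hot class labels
# Embeds the label, projects it to the size of `net`, and adds it per example.
conditioned = conditioning_utils.condition_tensor_from_onehot(net, one_hot)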
@@ -0,0 +1,76 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfgan.python.features.conditioning_utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.gan.python.features import conditioning_utils

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class ConditioningUtilsTest(test.TestCase):

  def test_condition_tensor_multiple_shapes(self):
    for tensor_shape in [(4, 1), (4, 2), (4, 2, 6), (None, 5, 3)]:
      for conditioning_shape in [(4, 1), (4, 8), (4, 5, 3)]:
        conditioning_utils.condition_tensor(
            array_ops.placeholder(dtypes.float32, tensor_shape),
            array_ops.placeholder(dtypes.float32, conditioning_shape))

  def test_condition_tensor_asserts(self):
    with self.assertRaisesRegexp(ValueError, 'Cannot reshape'):
      conditioning_utils.condition_tensor(
          array_ops.placeholder(dtypes.float32, (4, 1)),
          array_ops.placeholder(dtypes.float32, (5, 1)))

    with self.assertRaisesRegexp(ValueError, 'Shape .* is not fully defined'):
      conditioning_utils.condition_tensor(
          array_ops.placeholder(dtypes.float32, (5, None)),
          array_ops.placeholder(dtypes.float32, (5, 1)))

    with self.assertRaisesRegexp(ValueError, 'must have a least 2 dimensions.'):
      conditioning_utils.condition_tensor(
          array_ops.placeholder(dtypes.float32, (5, 2)),
          array_ops.placeholder(dtypes.float32, (5)))

  def test_condition_tensor_from_onehot(self):
    conditioning_utils.condition_tensor_from_onehot(
        array_ops.placeholder(dtypes.float32, (5, 4, 1)),
        array_ops.placeholder(dtypes.float32, (5, 10)))

  def test_condition_tensor_from_onehot_asserts(self):
    with self.assertRaisesRegexp(ValueError, 'Shape .* must have rank 2'):
      conditioning_utils.condition_tensor_from_onehot(
          array_ops.placeholder(dtypes.float32, (5, 1)),
          array_ops.placeholder(dtypes.float32, (5)))

    with self.assertRaisesRegexp(ValueError, 'Shape .* is not fully defined'):
      conditioning_utils.condition_tensor_from_onehot(
          array_ops.placeholder(dtypes.float32, (5, 1)),
          array_ops.placeholder(dtypes.float32, (5, None)))

    with self.assertRaisesRegexp(ValueError, 'Cannot reshape a tensor'):
      conditioning_utils.condition_tensor_from_onehot(
          array_ops.placeholder(dtypes.float32, (5, 1)),
          array_ops.placeholder(dtypes.float32, (4, 6)))


if __name__ == '__main__':
  test.main()
tensorflow/contrib/gan/python/features/virtual_batchnorm.py (Normal file, 306 lines)
@@ -0,0 +1,306 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Virtual batch normalization.

This technique was first introduced in `Improved Techniques for Training GANs`
(Salimans et al., https://arxiv.org/abs/1606.03498). Instead of using batch
normalization on a minibatch, it fixes a reference subset of the data to use
for calculating normalization statistics.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope

__all__ = [
    'VBN',
]


def _static_or_dynamic_batch_size(tensor, batch_axis):
  """Returns the static or dynamic batch size."""
  batch_size = array_ops.shape(tensor)[batch_axis]
  static_batch_size = tensor_util.constant_value(batch_size)
  return static_batch_size or batch_size


def _statistics(x, axes):
  """Calculate the mean and mean square of `x`.

  Modified from the implementation of `tf.nn.moments`.

  Args:
    x: A `Tensor`.
    axes: Array of ints. Axes along which to compute mean and variance.

  Returns:
    Two `Tensor` objects: `mean` and `square mean`.
  """
  # The dynamic range of fp16 is too limited to support the collection of
  # sufficient statistics. As a workaround we simply perform the operations
  # on 32-bit floats before converting the mean and variance back to fp16.
  y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x

  # Compute true mean while keeping the dims for proper broadcasting.
  shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes,
                                                       keep_dims=True))

  shifted_mean = math_ops.reduce_mean(y - shift, axes, keep_dims=True)
  mean = shifted_mean + shift
  mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keep_dims=True)

  mean = array_ops.squeeze(mean, axes)
  mean_squared = array_ops.squeeze(mean_squared, axes)
  if x.dtype == dtypes.float16:
    return (math_ops.cast(mean, dtypes.float16),
            math_ops.cast(mean_squared, dtypes.float16))
  else:
    return (mean, mean_squared)


def _validate_init_input_and_get_axis(reference_batch, axis):
  """Validate input and return the used axis value."""
  if reference_batch.shape.ndims is None:
    raise ValueError('`reference_batch` has unknown dimensions.')

  ndims = reference_batch.shape.ndims
  if axis < 0:
    used_axis = ndims + axis
  else:
    used_axis = axis
  if used_axis < 0 or used_axis >= ndims:
    raise ValueError('Value of `axis` argument ' + str(used_axis) +
                     ' is out of range for input with rank ' + str(ndims))
  return used_axis


def _validate_call_input(tensor_list, batch_dim):
  """Verifies that tensor shapes are compatible, except for `batch_dim`."""
  def _get_shape(tensor):
    shape = tensor.shape.as_list()
    del shape[batch_dim]
    return shape
  base_shape = tensor_shape.TensorShape(_get_shape(tensor_list[0]))
  for tensor in tensor_list:
    base_shape.assert_is_compatible_with(_get_shape(tensor))


class VBN(object):
  """A class to perform virtual batch normalization.

  This technique was first introduced in `Improved Techniques for Training
  GANs` (Salimans et al., https://arxiv.org/abs/1606.03498). Instead of using
  batch normalization on a minibatch, it fixes a reference subset of the data
  to use for calculating normalization statistics.

  To do this, we calculate the reference batch mean and mean square, and modify
  those statistics for each example. We use mean square instead of variance,
  since it is linear.

  Note that if `center` or `scale` variables are created, they are shared
  between all calls to this object.

  The `__init__` API is intended to mimic `tf.layers.batch_normalization` as
  closely as possible.
  """

  def __init__(self,
               reference_batch,
               axis=-1,
               epsilon=1e-3,
               center=True,
               scale=True,
               beta_initializer=init_ops.zeros_initializer(),
               gamma_initializer=init_ops.ones_initializer(),
               beta_regularizer=None,
               gamma_regularizer=None,
               trainable=True,
               name=None,
               batch_axis=0):
    """Initialize virtual batch normalization object.

    We precompute the 'mean' and 'mean squared' of the reference batch, so that
    `__call__` is efficient. This means that the axis must be supplied when the
    object is created, not when it is called.

    We precompute 'square mean' instead of 'variance', because the square mean
    can be easily adjusted on a per-example basis.

    Args:
      reference_batch: A minibatch Tensor. This will form the reference data
        from which the normalization statistics are calculated. See
        https://arxiv.org/abs/1606.03498 for more details.
      axis: Integer, the axis that should be normalized (typically the features
        axis). For instance, after a `Convolution2D` layer with
        `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
      epsilon: Small float added to variance to avoid dividing by zero.
      center: If True, add offset of `beta` to normalized tensor. If False,
        `beta` is ignored.
      scale: If True, multiply by `gamma`. If False, `gamma` is
        not used. When the next layer is linear (also e.g. `nn.relu`), this can
        be disabled since the scaling can be done by the next layer.
      beta_initializer: Initializer for the beta weight.
      gamma_initializer: Initializer for the gamma weight.
      beta_regularizer: Optional regularizer for the beta weight.
      gamma_regularizer: Optional regularizer for the gamma weight.
      trainable: Boolean, if `True` also add variables to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
      name: String, the name of the ops.
      batch_axis: The axis of the batch dimension. This dimension is treated
        differently in `virtual batch normalization` vs `batch normalization`.

    Raises:
      ValueError: If `reference_batch` has unknown dimensions at graph
        construction.
      ValueError: If `batch_axis` is the same as `axis`.
    """
    axis = _validate_init_input_and_get_axis(reference_batch, axis)
    self._epsilon = epsilon
    self._beta = 0
    self._gamma = 1
    self._batch_axis = _validate_init_input_and_get_axis(
        reference_batch, batch_axis)

    if axis == self._batch_axis:
      raise ValueError('`axis` and `batch_axis` cannot be the same.')

    with variable_scope.variable_scope(name, 'VBN',
                                       values=[reference_batch]) as self._vs:
      self._reference_batch = reference_batch

      # Calculate important shapes:
      #  1) Reduction axes for the reference batch
      #  2) Broadcast shape, if necessary
      #  3) Reduction axes for the virtual batchnormed batch
      #  4) Shape for optional parameters
      input_shape = self._reference_batch.shape
      ndims = input_shape.ndims
      reduction_axes = list(range(ndims))
      del reduction_axes[axis]

      self._broadcast_shape = [1] * len(input_shape)
      self._broadcast_shape[axis] = input_shape[axis].value

      self._example_reduction_axes = list(range(ndims))
      del self._example_reduction_axes[max(axis, self._batch_axis)]
      del self._example_reduction_axes[min(axis, self._batch_axis)]

      params_shape = self._reference_batch.shape[axis]

      # Determines whether broadcasting is needed. This is slightly different
      # than in the `nn.batch_normalization` case, due to `batch_dim`.
      self._needs_broadcasting = (
          sorted(self._example_reduction_axes) != list(range(ndims))[:-2])

      # Calculate the sufficient statistics for the reference batch in a way
      # that can be easily modified by additional examples.
      self._ref_mean, self._ref_mean_squares = _statistics(
          self._reference_batch, reduction_axes)
      self._ref_variance = (self._ref_mean_squares -
                            math_ops.square(self._ref_mean))

      # Virtual batch normalization uses a weighted average between example
      # statistics and the reference batch statistics.
      ref_batch_size = _static_or_dynamic_batch_size(
          self._reference_batch, self._batch_axis)
      self._example_weight = 1. / (math_ops.to_float(ref_batch_size) + 1.)
      self._ref_weight = 1. - self._example_weight

      # Make the variables, if necessary.
      if center:
        self._beta = variable_scope.get_variable(
            name='beta',
            shape=(params_shape,),
            initializer=beta_initializer,
            regularizer=beta_regularizer,
            trainable=trainable)
      if scale:
        self._gamma = variable_scope.get_variable(
            name='gamma',
            shape=(params_shape,),
            initializer=gamma_initializer,
            regularizer=gamma_regularizer,
            trainable=trainable)

  def _virtual_statistics(self, inputs, reduction_axes):
    """Compute the statistics needed for virtual batch normalization."""
    cur_mean, cur_mean_sq = _statistics(inputs, reduction_axes)
    vb_mean = (self._example_weight * cur_mean +
               self._ref_weight * self._ref_mean)
    vb_mean_sq = (self._example_weight * cur_mean_sq +
                  self._ref_weight * self._ref_mean_squares)
    return (vb_mean, vb_mean_sq)

  def _broadcast(self, v, broadcast_shape=None):
    # The exact broadcast shape depends on the current batch, not the reference
    # batch, unless we're calculating the batch normalization of the reference
    # batch.
    b_shape = broadcast_shape or self._broadcast_shape
    if self._needs_broadcasting and v is not None:
      return array_ops.reshape(v, b_shape)
    return v

  def reference_batch_normalization(self):
    """Return the reference batch, but batch normalized."""
    with ops.name_scope(self._vs.name):
      return nn.batch_normalization(self._reference_batch,
                                    self._broadcast(self._ref_mean),
                                    self._broadcast(self._ref_variance),
                                    self._broadcast(self._beta),
                                    self._broadcast(self._gamma),
                                    self._epsilon)

  def __call__(self, inputs):
    """Run virtual batch normalization on inputs.

    Args:
      inputs: Tensor input.

    Returns:
      A virtual batch normalized version of `inputs`.

    Raises:
      ValueError: If `inputs` shape isn't compatible with the reference batch.
    """
    _validate_call_input([inputs, self._reference_batch], self._batch_axis)

    with ops.name_scope(self._vs.name, values=[inputs, self._reference_batch]):
      # Calculate the statistics on the current input on a per-example basis.
      vb_mean, vb_mean_sq = self._virtual_statistics(
          inputs, self._example_reduction_axes)
      vb_variance = vb_mean_sq - math_ops.square(vb_mean)

      # The exact broadcast shape of the input statistic Tensors depends on the
      # current batch, not the reference batch. The parameter broadcast shape
      # is independent of the shape of the input statistic Tensor dimensions.
      b_shape = self._broadcast_shape[:]  # deep copy
      b_shape[self._batch_axis] = _static_or_dynamic_batch_size(
          inputs, self._batch_axis)
      return nn.batch_normalization(
          inputs,
          self._broadcast(vb_mean, b_shape),
          self._broadcast(vb_variance, b_shape),
          self._broadcast(self._beta, self._broadcast_shape),
          self._broadcast(self._gamma, self._broadcast_shape),
          self._epsilon)
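To make the intended workflow concrete, a minimal usage sketch (shapes are illustrative assumptions): the reference batch fixes the normalization statistics at construction time, and each later call normalizes new examples against a blend of those statistics and each example's own.

import tensorflow as tf
from tensorflow.contrib.gan.python.features import virtual_batchnorm

reference_batch = tf.random_normal([32, 7, 7, 64])  # fixed reference data
vbn = virtual_batchnorm.VBN(reference_batch)        # reference stats computed once here
minibatch = tf.random_normal([8, 7, 7, 64])
normalized = vbn(minibatch)                         # per-example virtual batch norm
ref_normed = vbn.reference_batch_normalization()    # plain batch norm of the reference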
tensorflow/contrib/gan/python/features/virtual_batchnorm_test.py (Normal file, 267 lines)
@@ -0,0 +1,267 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfgan.python.features.virtual_batchnorm."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.framework.python.ops import variables as contrib_variables_lib
from tensorflow.contrib.gan.python.features import virtual_batchnorm
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.layers import normalization
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test


class VirtualBatchnormTest(test.TestCase):

  def test_syntax(self):
    reference_batch = array_ops.zeros([5, 3, 16, 9, 15])
    vbn = virtual_batchnorm.VBN(reference_batch, batch_axis=1)
    vbn(array_ops.ones([5, 7, 16, 9, 15]))

  def test_no_broadcast_needed(self):
    """When `axis` and `batch_axis` are at the end, no broadcast is needed."""
    reference_batch = array_ops.zeros([5, 3, 16, 9, 15])
    minibatch = array_ops.zeros([5, 3, 16, 3, 15])
    vbn = virtual_batchnorm.VBN(reference_batch, axis=-1, batch_axis=-2)
    vbn(minibatch)

  def test_statistics(self):
    """Check that `_statistics` gives the same result as `nn.moments`."""
    random_seed.set_random_seed(1234)

    tensors = random_ops.random_normal([4, 5, 7, 3])
    for axes in [(3), (0, 2), (1, 2, 3)]:
      vb_mean, mean_sq = virtual_batchnorm._statistics(tensors, axes)
      mom_mean, mom_var = nn.moments(tensors, axes)
      vb_var = mean_sq - math_ops.square(vb_mean)

      with self.test_session(use_gpu=True) as sess:
        vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run([
            vb_mean, vb_var, mom_mean, mom_var])

      self.assertAllClose(mom_mean_np, vb_mean_np)
      self.assertAllClose(mom_var_np, vb_var_np)

  def test_virtual_statistics(self):
    """Check that `_virtual_statistics` gives same result as `nn.moments`."""
    random_seed.set_random_seed(1234)

    batch_axis = 0
    partial_batch = random_ops.random_normal([4, 5, 7, 3])
    single_example = random_ops.random_normal([1, 5, 7, 3])
    full_batch = array_ops.concat([partial_batch, single_example], axis=0)

    for reduction_axis in range(1, 4):
      # Get `nn.moments` on the full batch.
      reduction_axes = list(range(4))
      del reduction_axes[reduction_axis]
      mom_mean, mom_variance = nn.moments(full_batch, reduction_axes)

      # Get virtual batch statistics.
      vb_reduction_axes = list(range(4))
      del vb_reduction_axes[reduction_axis]
      del vb_reduction_axes[batch_axis]
      vbn = virtual_batchnorm.VBN(partial_batch, reduction_axis)
      vb_mean, mean_sq = vbn._virtual_statistics(
          single_example, vb_reduction_axes)
      vb_variance = mean_sq - math_ops.square(vb_mean)
      # Remove singleton batch dim for easy comparisons.
      vb_mean = array_ops.squeeze(vb_mean, batch_axis)
      vb_variance = array_ops.squeeze(vb_variance, batch_axis)

      with self.test_session(use_gpu=True) as sess:
        vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run([
            vb_mean, vb_variance, mom_mean, mom_variance])

      self.assertAllClose(mom_mean_np, vb_mean_np)
      self.assertAllClose(mom_var_np, vb_var_np)

  def test_reference_batch_normalization(self):
    """Check that batch norm from VBN agrees with opensource implementation."""
    random_seed.set_random_seed(1234)

    batch = random_ops.random_normal([6, 5, 7, 3, 3])

    for axis in range(5):
      # Get `layers` batchnorm result.
      bn_normalized = normalization.batch_normalization(
          batch, axis, training=True)

      # Get VBN's batch normalization on reference batch.
      batch_axis = 0 if axis != 0 else 1  # `axis` and `batch_axis` can't match
      vbn = virtual_batchnorm.VBN(batch, axis, batch_axis=batch_axis)
      vbn_normalized = vbn.reference_batch_normalization()

      with self.test_session(use_gpu=True) as sess:
        variables_lib.global_variables_initializer().run()

        bn_normalized_np, vbn_normalized_np = sess.run(
            [bn_normalized, vbn_normalized])
      self.assertAllClose(bn_normalized_np, vbn_normalized_np)

  def test_same_as_batchnorm(self):
    """Check that batch norm on set X matches VBN of X-without-y on `y`."""
    random_seed.set_random_seed(1234)

    num_examples = 4
    examples = [random_ops.random_normal([5, 7, 3]) for _ in
                range(num_examples)]

    # Get the result of the opensource batch normalization.
    batch_normalized = normalization.batch_normalization(
        array_ops.stack(examples), training=True)

    for i in range(num_examples):
      examples_except_i = array_ops.stack(examples[:i] + examples[i+1:])
      # Get the result of VBN's batch normalization.
      vbn = virtual_batchnorm.VBN(examples_except_i)
      vb_normed = array_ops.squeeze(
          vbn(array_ops.expand_dims(examples[i], [0])), [0])

      with self.test_session(use_gpu=True) as sess:
        variables_lib.global_variables_initializer().run()
        bn_np, vb_np = sess.run([batch_normalized, vb_normed])
      self.assertAllClose(bn_np[i, ...], vb_np)

  def test_minibatch_independent(self):
    """Test that virtual batch normalized examples are independent.

    Unlike batch normalization, virtual batch normalization has the property
    that the virtual batch normalized value of an example is independent of the
    other examples in the minibatch. In this test, we verify this property.
    """
    random_seed.set_random_seed(1234)

    # These can be random, but must be the same for all session calls.
    reference_batch = constant_op.constant(
        np.random.normal(size=[4, 7, 3]), dtype=dtypes.float32)
    fixed_example = constant_op.constant(np.random.normal(size=[7, 3]),
                                         dtype=dtypes.float32)

    # Get the VBN object and the virtual batch normalized value for
    # `fixed_example`.
    vbn = virtual_batchnorm.VBN(reference_batch)
    vbn_fixed_example = array_ops.squeeze(
        vbn(array_ops.expand_dims(fixed_example, 0)), 0)
    with self.test_session(use_gpu=True):
      variables_lib.global_variables_initializer().run()
      vbn_fixed_example_np = vbn_fixed_example.eval()

    # Check that the value is the same for different minibatches, and different
    # sized minibatches.
    for minibatch_size in range(1, 6):
      examples = [random_ops.random_normal([7, 3]) for _ in
                  range(minibatch_size)]

      minibatch = array_ops.stack([fixed_example] + examples)
      vbn_minibatch = vbn(minibatch)
      cur_vbn_fixed_example = vbn_minibatch[0, ...]
      with self.test_session(use_gpu=True):
        variables_lib.global_variables_initializer().run()
        cur_vbn_fixed_example_np = cur_vbn_fixed_example.eval()
      self.assertAllClose(vbn_fixed_example_np, cur_vbn_fixed_example_np)

  def test_variable_reuse(self):
    """Test that variable scopes work and inference on a real-ish case."""
    tensor1_ref = array_ops.zeros([6, 5, 7, 3, 3])
    tensor1_examples = array_ops.zeros([4, 5, 7, 3, 3])
    tensor2_ref = array_ops.zeros([4, 2, 3])
    tensor2_examples = array_ops.zeros([2, 2, 3])

    with variable_scope.variable_scope('dummy_scope', reuse=True):
      with self.assertRaisesRegexp(
          ValueError, 'does not exist, or was not created with '
          'tf.get_variable()'):
        virtual_batchnorm.VBN(tensor1_ref)

    vbn1 = virtual_batchnorm.VBN(tensor1_ref, name='vbn1')
    vbn2 = virtual_batchnorm.VBN(tensor2_ref, name='vbn2')

    # Fetch reference and examples after virtual batch normalization. Also
    # fetch in variable reuse case.
    to_fetch = []

    to_fetch.append(vbn1.reference_batch_normalization())
    to_fetch.append(vbn2.reference_batch_normalization())
    to_fetch.append(vbn1(tensor1_examples))
    to_fetch.append(vbn2(tensor2_examples))

    variable_scope.get_variable_scope().reuse_variables()

    to_fetch.append(vbn1.reference_batch_normalization())
    to_fetch.append(vbn2.reference_batch_normalization())
    to_fetch.append(vbn1(tensor1_examples))
    to_fetch.append(vbn2(tensor2_examples))

    self.assertEqual(4, len(contrib_variables_lib.get_variables()))

    with self.test_session(use_gpu=True) as sess:
      variables_lib.global_variables_initializer().run()
      sess.run(to_fetch)

  def test_invalid_input(self):
    # Reference batch has unknown dimensions.
    with self.assertRaisesRegexp(
        ValueError, '`reference_batch` has unknown dimensions.'):
      virtual_batchnorm.VBN(array_ops.placeholder(dtypes.float32), name='vbn1')

    # Axis too negative.
    with self.assertRaisesRegexp(
        ValueError, 'Value of `axis` argument .* is out of range'):
      virtual_batchnorm.VBN(array_ops.zeros([1, 2]), axis=-3, name='vbn2')

    # Axis too large.
    with self.assertRaisesRegexp(
        ValueError, 'Value of `axis` argument .* is out of range'):
      virtual_batchnorm.VBN(array_ops.zeros([1, 2]), axis=2, name='vbn3')

    # Batch axis too negative.
    with self.assertRaisesRegexp(
        ValueError, 'Value of `axis` argument .* is out of range'):
      virtual_batchnorm.VBN(array_ops.zeros([1, 2]), name='vbn4', batch_axis=-3)

    # Batch axis too large.
    with self.assertRaisesRegexp(
        ValueError, 'Value of `axis` argument .* is out of range'):
      virtual_batchnorm.VBN(array_ops.zeros([1, 2]), name='vbn5', batch_axis=2)

    # Axis and batch axis are the same.
    with self.assertRaisesRegexp(
        ValueError, '`axis` and `batch_axis` cannot be the same.'):
      virtual_batchnorm.VBN(array_ops.zeros(
          [1, 2]), axis=1, name='vbn6', batch_axis=1)

    # Reference Tensor and example Tensor have incompatible shapes.
    tensor_ref = array_ops.zeros([5, 2, 3])
    tensor_examples = array_ops.zeros([3, 2, 3])
    vbn = virtual_batchnorm.VBN(tensor_ref, name='vbn7', batch_axis=1)
    with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'):
      vbn(tensor_examples)


if __name__ == '__main__':
  test.main()