Migrate instance_norm to contrib/layers.

PiperOrigin-RevId: 169252675
This commit is contained in:
A. Unique TensorFlower 2017-09-19 09:33:23 -07:00 committed by TensorFlower Gardener
parent 5b94356280
commit 3588ff74d7
5 changed files with 352 additions and 0 deletions

tensorflow/contrib/layers/BUILD

@@ -58,6 +58,7 @@ tf_custom_op_py_library(
        "python/layers/feature_column_ops.py",
        "python/layers/initializers.py",
        "python/layers/layers.py",
        "python/layers/normalization.py",
        "python/layers/optimizers.py",
        "python/layers/regularizers.py",
        "python/layers/summaries.py",
@@ -176,6 +177,23 @@ py_test(
    ],
)

py_test(
    name = "normalization_test",
    size = "small",
    srcs = ["python/layers/normalization_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":layers_py",
        "//tensorflow/contrib/framework:framework_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:variables",
        "//third_party/py/numpy",
    ],
)

py_test(
    name = "optimizers_test",
    srcs = ["python/layers/optimizers_test.py"],

tensorflow/contrib/layers/__init__.py

@@ -95,6 +95,8 @@ See the @{$python/contrib.layers} guide.
@@weighted_sum_from_feature_columns
@@infer_real_valued_columns
@@sequence_input_from_feature_columns
@@instance_norm
"""
from __future__ import absolute_import
@@ -112,6 +114,7 @@ _allowed_symbols = ['bias_add',
                    'conv3d',
                    'elu',
                    'feature_column',
                    'instance_norm',
                    'legacy_fully_connected',
                    'legacy_linear',
                    'legacy_relu',
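
Whitelisting the symbol in `_allowed_symbols` is what exposes it on the public contrib namespace; a minimal usage sketch under that assumption (the 4-D `images` placeholder is hypothetical):

    import tensorflow as tf

    images = tf.placeholder(tf.float32, shape=(None, 32, 32, 3))
    net = tf.contrib.layers.instance_norm(images)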

tensorflow/contrib/layers/python/layers/__init__.py

@@ -25,6 +25,7 @@ from tensorflow.contrib.layers.python.layers.feature_column import *
from tensorflow.contrib.layers.python.layers.feature_column_ops import *
from tensorflow.contrib.layers.python.layers.initializers import *
from tensorflow.contrib.layers.python.layers.layers import *
from tensorflow.contrib.layers.python.layers.normalization import *
from tensorflow.contrib.layers.python.layers.optimizers import *
from tensorflow.contrib.layers.python.layers.regularizers import *
from tensorflow.contrib.layers.python.layers.summaries import *

tensorflow/contrib/layers/python/layers/normalization.py

@@ -0,0 +1,160 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Contains the normalization layer classes and their functional aliases."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.framework.python.ops import add_arg_scope
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.layers.python.layers import utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope
__all__ = [
'instance_norm',
]
DATA_FORMAT_NCHW = 'NCHW'
DATA_FORMAT_NHWC = 'NHWC'


@add_arg_scope
def instance_norm(inputs,
                  center=True,
                  scale=True,
                  epsilon=1e-6,
                  activation_fn=None,
                  param_initializers=None,
                  reuse=None,
                  variables_collections=None,
                  outputs_collections=None,
                  trainable=True,
                  data_format=DATA_FORMAT_NHWC,
                  scope=None):
"""Functional interface for the instance normalization layer.
Reference: https://arxiv.org/abs/1607.08022.
"Instance Normalization: The Missing Ingredient for Fast Stylization"
Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky
Args:
inputs: A tensor with 2 or more dimensions, where the first dimension has
`batch_size`. The normalization is over all but the last dimension if
`data_format` is `NHWC` and the second dimension if `data_format` is
`NCHW`.
center: If True, add offset of `beta` to normalized tensor. If False, `beta`
is ignored.
scale: If True, multiply by `gamma`. If False, `gamma` is
not used. When the next layer is linear (also e.g. `nn.relu`), this can be
disabled since the scaling can be done by the next layer.
epsilon: Small float added to variance to avoid dividing by zero.
activation_fn: Activation function, default set to None to skip it and
maintain a linear activation.
param_initializers: Optional initializers for beta, gamma, moving mean and
moving variance.
reuse: Whether or not the layer and its variables should be reused. To be
able to reuse the layer scope must be given.
variables_collections: Optional collections for the variables.
outputs_collections: Collections to add the outputs.
trainable: If `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
data_format: A string. `NHWC` (default) and `NCHW` are supported.
scope: Optional scope for `variable_scope`.
Returns:
A `Tensor` representing the output of the operation.
Raises:
ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
ValueError: If the rank of `inputs` is undefined.
ValueError: If rank or channels dimension of `inputs` is undefined.
"""
  inputs = ops.convert_to_tensor(inputs)
  inputs_shape = inputs.shape
  inputs_rank = inputs.shape.ndims

  if inputs_rank is None:
    raise ValueError('Inputs %s has undefined rank.' % inputs.name)
  if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
    raise ValueError('data_format has to be either NCHW or NHWC.')

  with variable_scope.variable_scope(
      scope, 'InstanceNorm', [inputs], reuse=reuse) as sc:
    if data_format == DATA_FORMAT_NCHW:
      reduction_axis = 1
      # For NCHW format, rather than relying on implicit broadcasting, we
      # explicitly reshape the params to params_shape_broadcast when computing
      # the moments and the batch normalization.
      params_shape_broadcast = list(
          [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)])
    else:
      reduction_axis = inputs_rank - 1
      params_shape_broadcast = None

    moments_axes = list(range(inputs_rank))
    del moments_axes[reduction_axis]
    del moments_axes[0]

    params_shape = inputs_shape[reduction_axis:reduction_axis + 1]
    if not params_shape.is_fully_defined():
      raise ValueError('Inputs %s has undefined channels dimension %s.' % (
          inputs.name, params_shape))

    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    dtype = inputs.dtype.base_dtype
    if param_initializers is None:
      param_initializers = {}
    if center:
      beta_collections = utils.get_variable_collections(
          variables_collections, 'beta')
      beta_initializer = param_initializers.get(
          'beta', init_ops.zeros_initializer())
      beta = variables.model_variable('beta',
                                      shape=params_shape,
                                      dtype=dtype,
                                      initializer=beta_initializer,
                                      collections=beta_collections,
                                      trainable=trainable)
      if params_shape_broadcast:
        beta = array_ops.reshape(beta, params_shape_broadcast)
    if scale:
      gamma_collections = utils.get_variable_collections(
          variables_collections, 'gamma')
      gamma_initializer = param_initializers.get(
          'gamma', init_ops.ones_initializer())
      gamma = variables.model_variable('gamma',
                                       shape=params_shape,
                                       dtype=dtype,
                                       initializer=gamma_initializer,
                                       collections=gamma_collections,
                                       trainable=trainable)
      if params_shape_broadcast:
        gamma = array_ops.reshape(gamma, params_shape_broadcast)

    # Calculate the moments (instance activations).
    mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)

    # Compute instance normalization.
    outputs = nn.batch_normalization(
        inputs, mean, variance, beta, gamma, epsilon, name='instancenorm')
    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
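
In effect, each (sample, channel) pair is normalized over the spatial axes: y = gamma * (x - mu) / sqrt(sigma^2 + epsilon) + beta, where mu and sigma^2 are computed per sample and per channel. A minimal NumPy sketch of the NHWC case, for reference only (the helper name and epsilon default mirror the code above but are otherwise hypothetical, not part of this commit):

    import numpy as np

    def instance_norm_reference(x, epsilon=1e-6):
      # Normalize every (sample, channel) slice over the spatial axes,
      # matching moments_axes = all axes except batch (0) and channels (-1).
      spatial_axes = tuple(range(1, x.ndim - 1))
      mean = x.mean(axis=spatial_axes, keepdims=True)
      var = x.var(axis=spatial_axes, keepdims=True)
      return (x - mean) / np.sqrt(var + epsilon)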

tensorflow/contrib/layers/python/layers/normalization_test.py

@@ -0,0 +1,170 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for contrib.layers.python.layers.normalization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.layers.python.layers import normalization
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class InstanceNormTest(test.TestCase):

  def testUnknownShape(self):
    inputs = array_ops.placeholder(dtypes.float32)
    with self.assertRaisesRegexp(ValueError, 'undefined rank'):
      normalization.instance_norm(inputs)

  def testBadDataFormat(self):
    inputs = array_ops.placeholder(dtypes.float32, shape=(2, 5, 5))
    with self.assertRaisesRegexp(ValueError,
                                 'data_format has to be either NCHW or NHWC.'):
      normalization.instance_norm(inputs, data_format='NHCW')

  def testParamsShapeNotFullyDefinedNCHW(self):
    inputs = array_ops.placeholder(dtypes.float32, shape=(3, None, 4))
    with self.assertRaisesRegexp(ValueError, 'undefined channels dimension'):
      normalization.instance_norm(inputs, data_format='NCHW')

  def testParamsShapeNotFullyDefinedNHWC(self):
    inputs = array_ops.placeholder(dtypes.float32, shape=(3, 4, None))
    with self.assertRaisesRegexp(ValueError, 'undefined channels dimension'):
      normalization.instance_norm(inputs, data_format='NHWC')

  def testCreateOp(self):
    height, width = 3, 3
    images = random_ops.random_uniform((5, height, width, 3), seed=1)
    output = normalization.instance_norm(images)
    self.assertStartsWith(
        output.op.name, 'InstanceNorm/instancenorm')
    self.assertListEqual([5, height, width, 3], output.shape.as_list())
  def testCreateOpFloat64(self):
    height, width = 3, 3
    images = random_ops.random_uniform(
        (5, height, width, 3), dtype=dtypes.float64, seed=1)
    output = normalization.instance_norm(images)
    self.assertStartsWith(
        output.op.name, 'InstanceNorm/instancenorm')
    self.assertListEqual([5, height, width, 3], output.shape.as_list())

  def testCreateOpNoScaleCenter(self):
    height, width = 3, 3
    images = random_ops.random_uniform(
        (5, height, width, 3), dtype=dtypes.float64, seed=1)
    output = normalization.instance_norm(images, center=False, scale=False)
    self.assertStartsWith(
        output.op.name, 'InstanceNorm/instancenorm')
    self.assertListEqual([5, height, width, 3], output.shape.as_list())
    self.assertEqual(0, len(contrib_variables.get_variables_by_name('beta')))
    self.assertEqual(0, len(contrib_variables.get_variables_by_name('gamma')))

  def testCreateVariables(self):
    height, width = 3, 3
    images = random_ops.random_uniform((5, height, width, 3), seed=1)
    normalization.instance_norm(images, center=True, scale=True)
    beta = contrib_variables.get_variables_by_name('beta')[0]
    gamma = contrib_variables.get_variables_by_name('gamma')[0]
    self.assertEqual('InstanceNorm/beta', beta.op.name)
    self.assertEqual('InstanceNorm/gamma', gamma.op.name)

  def testReuseVariables(self):
    height, width = 3, 3
    images = random_ops.random_uniform((5, height, width, 3), seed=1)
    normalization.instance_norm(images, scale=True, scope='IN')
    normalization.instance_norm(images, scale=True, scope='IN', reuse=True)
    beta = contrib_variables.get_variables_by_name('beta')
    gamma = contrib_variables.get_variables_by_name('gamma')
    self.assertEqual(1, len(beta))
    self.assertEqual(1, len(gamma))

  def testValueCorrectWithReuseVars(self):
    height, width = 3, 3
    image_shape = (10, height, width, 3)
    images = random_ops.random_uniform(image_shape, seed=1)
    output_train = normalization.instance_norm(images, scope='IN')
    output_eval = normalization.instance_norm(images, scope='IN', reuse=True)
    with self.test_session() as sess:
      sess.run(variables.global_variables_initializer())
      # output_train and output_eval should be the same.
      train_np, eval_np = sess.run([output_train, output_eval])
      self.assertAllClose(train_np, eval_np)
  def doOutputTest(self, input_shape, data_format, tol=1e-3):
    axis = -1 if data_format == 'NHWC' else 1
    for mu in (0.0, 1e2):
      for sigma in (1.0, 0.1):
        # Determine shape of Tensor after normalization.
        reduced_shape = (input_shape[0], input_shape[axis])
        expected_mean = np.zeros(reduced_shape)
        expected_var = np.ones(reduced_shape)

        # Determine axes that will be normalized.
        reduced_axes = list(range(len(input_shape)))
        del reduced_axes[axis]
        del reduced_axes[0]
        reduced_axes = tuple(reduced_axes)

        inputs = random_ops.random_uniform(input_shape, seed=0) * sigma + mu
        output_op = normalization.instance_norm(
            inputs, center=False, scale=False, data_format=data_format)
        with self.test_session() as sess:
          sess.run(variables.global_variables_initializer())
          outputs = sess.run(output_op)

          # Make sure that there are no NaNs
          self.assertFalse(np.isnan(outputs).any())
          mean = np.mean(outputs, axis=reduced_axes)
          var = np.var(outputs, axis=reduced_axes)
          # The mean and variance of each example should be close to 0 and 1
          # respectively.
          self.assertAllClose(expected_mean, mean, rtol=tol, atol=tol)
          self.assertAllClose(expected_var, var, rtol=tol, atol=tol)
  def testOutputSmallInput4DNHWC(self):
    self.doOutputTest((10, 10, 10, 30), 'NHWC', tol=1e-2)

  def testOutputSmallInput4DNCHW(self):
    self.doOutputTest((10, 10, 10, 30), 'NCHW', tol=1e-2)

  def testOutputBigInput4DNHWC(self):
    self.doOutputTest((1, 100, 100, 1), 'NHWC', tol=1e-3)

  def testOutputBigInput4DNCHW(self):
    self.doOutputTest((1, 100, 100, 1), 'NCHW', tol=1e-3)

  def testOutputSmallInput5DNHWC(self):
    self.doOutputTest((10, 10, 10, 10, 30), 'NHWC', tol=1e-2)

  def testOutputSmallInput5DNCHW(self):
    self.doOutputTest((10, 10, 10, 10, 30), 'NCHW', tol=1e-2)

  def testOutputBigInput5DNHWC(self):
    self.doOutputTest((1, 100, 100, 1, 1), 'NHWC', tol=1e-3)

  def testOutputBigInput5DNCHW(self):
    self.doOutputTest((1, 100, 100, 1, 1), 'NCHW', tol=1e-3)


if __name__ == '__main__':
  test.main()
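
Since `instance_norm` is decorated with `@add_arg_scope`, shared defaults can also be set once for several call sites; a hedged sketch against the contrib API of this era (the placeholder shape and scope names are illustrative only):

    import tensorflow as tf
    from tensorflow.contrib.framework import arg_scope

    inputs = tf.placeholder(tf.float32, shape=(None, 32, 32, 3))
    with arg_scope([tf.contrib.layers.instance_norm], epsilon=1e-5):
      net = tf.contrib.layers.instance_norm(inputs, scope='in1')
      net = tf.contrib.layers.instance_norm(net, scope='in2')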