Migrate instance_norm
to contrib/layers.
PiperOrigin-RevId: 169252675
This commit is contained in:
parent
5b94356280
commit
3588ff74d7
@ -58,6 +58,7 @@ tf_custom_op_py_library(
|
||||
"python/layers/feature_column_ops.py",
|
||||
"python/layers/initializers.py",
|
||||
"python/layers/layers.py",
|
||||
"python/layers/normalization.py",
|
||||
"python/layers/optimizers.py",
|
||||
"python/layers/regularizers.py",
|
||||
"python/layers/summaries.py",
|
||||
@ -176,6 +177,23 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
# Runs python/layers/normalization_test.py against the contrib layers library.
py_test(
    name = "normalization_test",
    size = "small",
    srcs = ["python/layers/normalization_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":layers_py",
        "//tensorflow/contrib/framework:framework_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:variables",
        "//third_party/py/numpy",
    ],
)
|
||||
|
||||
py_test(
|
||||
name = "optimizers_test",
|
||||
srcs = ["python/layers/optimizers_test.py"],
|
||||
|
@ -95,6 +95,8 @@ See the @{$python/contrib.layers} guide.
|
||||
@@weighted_sum_from_feature_columns
|
||||
@@infer_real_valued_columns
|
||||
@@sequence_input_from_feature_columns
|
||||
|
||||
@@instance_norm
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
@ -112,6 +114,7 @@ _allowed_symbols = ['bias_add',
|
||||
'conv3d',
|
||||
'elu',
|
||||
'feature_column',
|
||||
'instance_norm',
|
||||
'legacy_fully_connected',
|
||||
'legacy_linear',
|
||||
'legacy_relu',
|
||||
|
@ -25,6 +25,7 @@ from tensorflow.contrib.layers.python.layers.feature_column import *
|
||||
from tensorflow.contrib.layers.python.layers.feature_column_ops import *
|
||||
from tensorflow.contrib.layers.python.layers.initializers import *
|
||||
from tensorflow.contrib.layers.python.layers.layers import *
|
||||
from tensorflow.contrib.layers.python.layers.normalization import *
|
||||
from tensorflow.contrib.layers.python.layers.optimizers import *
|
||||
from tensorflow.contrib.layers.python.layers.regularizers import *
|
||||
from tensorflow.contrib.layers.python.layers.summaries import *
|
||||
|
160
tensorflow/contrib/layers/python/layers/normalization.py
Normal file
160
tensorflow/contrib/layers/python/layers/normalization.py
Normal file
@ -0,0 +1,160 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""Contains the normalization layer classes and their functional aliases."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
from tensorflow.contrib.framework.python.ops import add_arg_scope
|
||||
from tensorflow.contrib.framework.python.ops import variables
|
||||
from tensorflow.contrib.layers.python.layers import utils
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops import variable_scope
|
||||
|
||||
|
||||
# Names exported by `from ...normalization import *`.
__all__ = ['instance_norm']

# Values accepted by the `data_format` argument of `instance_norm`.
DATA_FORMAT_NHWC = 'NHWC'
DATA_FORMAT_NCHW = 'NCHW'
|
||||
|
||||
|
||||
@add_arg_scope
def instance_norm(inputs,
                  center=True,
                  scale=True,
                  epsilon=1e-6,
                  activation_fn=None,
                  param_initializers=None,
                  reuse=None,
                  variables_collections=None,
                  outputs_collections=None,
                  trainable=True,
                  data_format=DATA_FORMAT_NHWC,
                  scope=None):
  """Functional interface for the instance normalization layer.

  Reference: https://arxiv.org/abs/1607.08022.

    "Instance Normalization: The Missing Ingredient for Fast Stylization"
    Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky

  Args:
    inputs: A tensor with 2 or more dimensions, where the first dimension has
      `batch_size`. Moments are computed separately for every example and
      every channel, i.e. over all axes except the batch axis and the channels
      axis. The channels axis is the last one if `data_format` is `NHWC` and
      the second one if `data_format` is `NCHW`.
    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
      is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is
      not used. When the next layer is linear (also e.g. `nn.relu`), this can be
      disabled since the scaling can be done by the next layer.
    epsilon: Small float added to variance to avoid dividing by zero.
    activation_fn: Activation function, default set to None to skip it and
      maintain a linear activation.
    param_initializers: Optional dict of initializers, looked up under the
      keys `'beta'` and `'gamma'`. Missing entries default to zeros for `beta`
      and ones for `gamma`.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional collections for the variables.
    outputs_collections: Collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    data_format: A string. `NHWC` (default) and `NCHW` are supported.
    scope: Optional scope for `variable_scope`.

  Returns:
    A `Tensor` representing the output of the operation.

  Raises:
    ValueError: If the rank of `inputs` is undefined or smaller than 2.
    ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
    ValueError: If the channels dimension of `inputs` is undefined.
  """
  inputs = ops.convert_to_tensor(inputs)
  inputs_shape = inputs.shape
  inputs_rank = inputs_shape.ndims

  if inputs_rank is None:
    raise ValueError('Inputs %s has undefined rank.' % inputs.name)
  if inputs_rank < 2:
    # Without this guard the axis bookkeeping below fails with an unhelpful
    # IndexError, because there is no separate batch and channels axis.
    raise ValueError('Inputs %s must have rank of at least 2, got rank %d.' % (
        inputs.name, inputs_rank))
  if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
    raise ValueError('data_format has to be either NCHW or NHWC.')

  with variable_scope.variable_scope(
      scope, 'InstanceNorm', [inputs], reuse=reuse) as sc:
    if data_format == DATA_FORMAT_NCHW:
      reduction_axis = 1
      # For NCHW format, rather than relying on implicit broadcasting, we
      # explicitly reshape the params to params_shape_broadcast when computing
      # the moments and the batch normalization.
      params_shape_broadcast = (
          [1, inputs_shape[1].value] + [1] * (inputs_rank - 2))
    else:
      reduction_axis = inputs_rank - 1
      params_shape_broadcast = None

    # Reduce over everything except the batch axis (0) and the channels axis.
    moments_axes = [
        axis for axis in range(inputs_rank) if axis not in (0, reduction_axis)
    ]
    params_shape = inputs_shape[reduction_axis:reduction_axis + 1]
    if not params_shape.is_fully_defined():
      raise ValueError('Inputs %s has undefined channels dimension %s.' % (
          inputs.name, params_shape))

    dtype = inputs.dtype.base_dtype
    if param_initializers is None:
      param_initializers = {}

    def _make_param(name, default_initializer):
      """Creates the per-channel model variable `name`, reshaped for NCHW."""
      collections = utils.get_variable_collections(variables_collections, name)
      initializer = param_initializers.get(name, default_initializer)
      param = variables.model_variable(name,
                                       shape=params_shape,
                                       dtype=dtype,
                                       initializer=initializer,
                                       collections=collections,
                                       trainable=trainable)
      if params_shape_broadcast:
        param = array_ops.reshape(param, params_shape_broadcast)
      return param

    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    if center:
      beta = _make_param('beta', init_ops.zeros_initializer())
    if scale:
      gamma = _make_param('gamma', init_ops.ones_initializer())

    # Calculate the moments (instance activations). keep_dims=True keeps the
    # reduced axes as size 1 so the moments broadcast against `inputs`.
    mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)

    # Compute instance normalization.
    outputs = nn.batch_normalization(
        inputs, mean, variance, beta, gamma, epsilon, name='instancenorm')
    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
|
170
tensorflow/contrib/layers/python/layers/normalization_test.py
Normal file
170
tensorflow/contrib/layers/python/layers/normalization_test.py
Normal file
@ -0,0 +1,170 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""Tests for contrib.layers.python.layers.normalization."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
|
||||
from tensorflow.contrib.layers.python.layers import normalization
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class InstanceNormTest(test.TestCase):
  """Tests for `normalization.instance_norm`."""

  def testUnknownShape(self):
    """A placeholder without a static rank is rejected."""
    inputs = array_ops.placeholder(dtypes.float32)
    with self.assertRaisesRegexp(ValueError, 'undefined rank'):
      normalization.instance_norm(inputs)

  def testBadDataFormat(self):
    """A data_format other than NCHW/NHWC is rejected."""
    inputs = array_ops.placeholder(dtypes.float32, shape=(2, 5, 5))
    with self.assertRaisesRegexp(ValueError,
                                 'data_format has to be either NCHW or NHWC.'):
      normalization.instance_norm(inputs, data_format='NHCW')

  def testParamsShapeNotFullyDefinedNCHW(self):
    """An unknown channels dimension (axis 1) is rejected for NCHW."""
    inputs = array_ops.placeholder(dtypes.float32, shape=(3, None, 4))
    with self.assertRaisesRegexp(ValueError, 'undefined channels dimension'):
      normalization.instance_norm(inputs, data_format='NCHW')

  def testParamsShapeNotFullyDefinedNHWC(self):
    """An unknown channels dimension (last axis) is rejected for NHWC."""
    inputs = array_ops.placeholder(dtypes.float32, shape=(3, 4, None))
    with self.assertRaisesRegexp(ValueError, 'undefined channels dimension'):
      normalization.instance_norm(inputs, data_format='NHWC')

  def testCreateOp(self):
    """The output op is named under InstanceNorm and keeps the input shape."""
    height, width = 3, 3
    images = random_ops.random_uniform((5, height, width, 3), seed=1)
    output = normalization.instance_norm(images)
    self.assertStartsWith(
        output.op.name, 'InstanceNorm/instancenorm')
    self.assertListEqual([5, height, width, 3], output.shape.as_list())

  def testCreateOpFloat64(self):
    """Same as testCreateOp, but with float64 inputs."""
    height, width = 3, 3
    images = random_ops.random_uniform(
        (5, height, width, 3), dtype=dtypes.float64, seed=1)
    output = normalization.instance_norm(images)
    self.assertStartsWith(
        output.op.name, 'InstanceNorm/instancenorm')
    self.assertListEqual([5, height, width, 3], output.shape.as_list())

  def testCreateOpNoScaleCenter(self):
    """With center=False and scale=False no beta/gamma variables exist."""
    height, width = 3, 3
    images = random_ops.random_uniform(
        (5, height, width, 3), dtype=dtypes.float64, seed=1)
    output = normalization.instance_norm(images, center=False, scale=False)
    self.assertStartsWith(
        output.op.name, 'InstanceNorm/instancenorm')
    self.assertListEqual([5, height, width, 3], output.shape.as_list())
    self.assertEqual(0, len(contrib_variables.get_variables_by_name('beta')))
    self.assertEqual(0, len(contrib_variables.get_variables_by_name('gamma')))

  def testCreateVariables(self):
    """beta and gamma are created under the default InstanceNorm scope."""
    height, width = 3, 3
    images = random_ops.random_uniform((5, height, width, 3), seed=1)
    normalization.instance_norm(images, center=True, scale=True)
    beta = contrib_variables.get_variables_by_name('beta')[0]
    gamma = contrib_variables.get_variables_by_name('gamma')[0]
    self.assertEqual('InstanceNorm/beta', beta.op.name)
    self.assertEqual('InstanceNorm/gamma', gamma.op.name)

  def testReuseVariables(self):
    """A second call with reuse=True does not create new variables."""
    height, width = 3, 3
    images = random_ops.random_uniform((5, height, width, 3), seed=1)
    normalization.instance_norm(images, scale=True, scope='IN')
    normalization.instance_norm(images, scale=True, scope='IN', reuse=True)
    beta = contrib_variables.get_variables_by_name('beta')
    gamma = contrib_variables.get_variables_by_name('gamma')
    self.assertEqual(1, len(beta))
    self.assertEqual(1, len(gamma))

  def testValueCorrectWithReuseVars(self):
    """Two layers sharing variables produce identical outputs."""
    height, width = 3, 3
    image_shape = (10, height, width, 3)
    images = random_ops.random_uniform(image_shape, seed=1)
    output_train = normalization.instance_norm(images, scope='IN')
    output_eval = normalization.instance_norm(images, scope='IN', reuse=True)
    with self.test_session() as sess:
      sess.run(variables.global_variables_initializer())
      # output_train and output_eval should be the same.
      train_np, eval_np = sess.run([output_train, output_eval])
      self.assertAllClose(train_np, eval_np)

  def doOutputTest(self, input_shape, data_format, tol=1e-3):
    """Checks that outputs have ~zero mean and ~unit variance per instance.

    Runs `instance_norm` without beta/gamma on random inputs of several
    offsets and scales, then verifies the per-example, per-channel statistics
    of the result.

    Args:
      input_shape: Shape of the random input tensor.
      data_format: 'NHWC' or 'NCHW'; selects the channels axis.
      tol: Relative and absolute tolerance for the mean/variance comparison.
    """
    axis = -1 if data_format == 'NHWC' else 1
    for mu in (0.0, 1e2):
      for sigma in (1.0, 0.1):
        # Determine shape of Tensor after normalization.
        reduced_shape = (input_shape[0], input_shape[axis])
        expected_mean = np.zeros(reduced_shape)
        expected_var = np.ones(reduced_shape)

        # Determine axes that will be normalized.
        reduced_axes = list(range(len(input_shape)))
        del reduced_axes[axis]
        del reduced_axes[0]
        reduced_axes = tuple(reduced_axes)

        inputs = random_ops.random_uniform(input_shape, seed=0) * sigma + mu
        output_op = normalization.instance_norm(
            inputs, center=False, scale=False, data_format=data_format)
        with self.test_session() as sess:
          sess.run(variables.global_variables_initializer())
          outputs = sess.run(output_op)
          # Make sure that there are no NaNs
          self.assertFalse(np.isnan(outputs).any())
          mean = np.mean(outputs, axis=reduced_axes)
          var = np.var(outputs, axis=reduced_axes)
          # The mean and variance of each example should be close to 0 and 1
          # respectively.
          self.assertAllClose(expected_mean, mean, rtol=tol, atol=tol)
          self.assertAllClose(expected_var, var, rtol=tol, atol=tol)

  def testOutputSmallInput4DNHWC(self):
    self.doOutputTest((10, 10, 10, 30), 'NHWC', tol=1e-2)

  def testOutputSmallInput4DNCHW(self):
    self.doOutputTest((10, 10, 10, 30), 'NCHW', tol=1e-2)

  def testOutputBigInput4DNHWC(self):
    self.doOutputTest((1, 100, 100, 1), 'NHWC', tol=1e-3)

  def testOutputBigInput4DNCHW(self):
    self.doOutputTest((1, 100, 100, 1), 'NCHW', tol=1e-3)

  def testOutputSmallInput5DNHWC(self):
    self.doOutputTest((10, 10, 10, 10, 30), 'NHWC', tol=1e-2)

  def testOutputSmallInput5DNCHW(self):
    self.doOutputTest((10, 10, 10, 10, 30), 'NCHW', tol=1e-2)

  def testOutputBigInput5DNHWC(self):
    self.doOutputTest((1, 100, 100, 1, 1), 'NHWC', tol=1e-3)

  def testOutputBigInput5DNCHW(self):
    self.doOutputTest((1, 100, 100, 1, 1), 'NCHW', tol=1e-3)
|
||||
|
||||
# Script entry point: hand control to the TensorFlow test runner.
if __name__ == '__main__':
  test.main()
|
Loading…
Reference in New Issue
Block a user