# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for batch_norm related functionality in tensorflow.ops.nn."""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import numpy as np
|
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
|
import tensorflow as tf
|
|
|
|
from tensorflow.python.ops import gen_nn_ops
|
|
|
|
|
|
class BatchNormalizationTest(tf.test.TestCase):
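
  # All of the helpers below compute the same reference transform
  # (epsilon keeps the denominator away from zero):
  #
  #   y = (x - m) / sqrt(v + epsilon)
  #   y = y * gamma        # only when scale_after_normalization
  #   y = y + beta         # only when shift_after_normalization
  #
  # _npBatchNorm evaluates it in NumPy, _opsBatchNorm composes it from
  # primitive TensorFlow ops, and the _tfBatchNormV* helpers wrap the
  # fused kernels; the tests check that all of them agree.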
  def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
                   scale_after_normalization, shift_after_normalization):
    y = (x - m) / np.sqrt(v + epsilon)
    y = y * gamma if scale_after_normalization else y
    return y + beta if shift_after_normalization else y

  def _opsBatchNorm(self, x, m, v, beta, gamma, epsilon,
                    scale_after_normalization, shift_after_normalization):
    y = (x - m) * tf.rsqrt(v + epsilon)
    if scale_after_normalization:
      y = gamma * y
    return y + beta if shift_after_normalization else y

  def _tfBatchNormV1(self, x, m, v, beta, gamma, epsilon,
                     scale_after_normalization):
    """Original implementation."""
    # _batch_norm_with_global_normalization is deprecated in v9
    tf.get_default_graph().graph_def_versions.producer = 8
    # pylint: disable=protected-access
    return gen_nn_ops._batch_norm_with_global_normalization(
        x, m, v, beta, gamma, epsilon, scale_after_normalization)
    # pylint: enable=protected-access

  def _tfBatchNormV1BW(self, x, m, v, beta, gamma, epsilon,
                       scale_after_normalization):
    """Re-implementation of the original kernel for backward compatibility."""
    return tf.nn.batch_norm_with_global_normalization(
        x, m, v, beta, gamma, epsilon, scale_after_normalization)

  def _tfBatchNormV2(self, x, m, v, beta, gamma, epsilon,
                     scale_after_normalization, shift_after_normalization):
    """New implementation."""
    return tf.nn.batch_normalization(
        x, m, v, beta if shift_after_normalization else None,
        gamma if scale_after_normalization else None, epsilon)

  def testBatchNorm(self):
    x_shape = [3, 5, 4, 2]
    param_shape = [2]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        x = tf.constant(x_val, name="x")
        m = tf.constant(m_val, name="m")
        v = tf.constant(v_val, name="v")
        beta = tf.constant(beta_val, name="beta")
        gamma = tf.constant(gamma_val, name="gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn2 = self._tfBatchNormV2(
                x, m, v, beta, gamma, epsilon, scale_after_normalization,
                shift_after_normalization)
            bn1bw = self._tfBatchNormV1BW(
                x, m, v, beta, gamma, epsilon, scale_after_normalization)
            bn1 = self._tfBatchNormV1(
                x, m, v, beta, gamma, epsilon, scale_after_normalization)
            on = self._opsBatchNorm(
                x, m, v, beta, gamma, epsilon, scale_after_normalization,
                shift_after_normalization)
            np_bn = self._npBatchNorm(
                x_val, m_val, v_val, beta_val, gamma_val, epsilon,
                scale_after_normalization, shift_after_normalization)
            tf_bn_v2, tf_bn_v1bw, tf_bn_v1, ops_bn = sess.run(
                [bn2, bn1bw, bn1, on])
            self.assertAllClose(np_bn, ops_bn, atol=0.00001)
            self.assertAllClose(np_bn, tf_bn_v2, atol=0.00001)
            self.assertAllClose(tf_bn_v2, ops_bn, atol=0.00001)
            # shift_after_normalization=False is not supported in v1.
            if shift_after_normalization:
              self.assertAllClose(np_bn, tf_bn_v1bw, atol=0.00001)
              self.assertAllClose(np_bn, tf_bn_v1, atol=0.00001)
              self.assertAllClose(tf_bn_v1, ops_bn, atol=0.00001)
              self.assertAllClose(tf_bn_v1bw, ops_bn, atol=0.00001)

  def _testBatchNormGradient(self, param_index, tag, scale_after_normalization,
                             shift_after_normalization, version,
                             err_tolerance=1e-11):
    x_shape = [3, 5, 4, 5]
    param_shape = [5]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float64)
    m_val = np.random.random_sample(param_shape).astype(np.float64)
    v_val = np.random.random_sample(param_shape).astype(np.float64)
    beta_val = np.random.random_sample(param_shape).astype(np.float64)
    gamma_val = np.random.random_sample(param_shape).astype(np.float64)
    with self.test_session():
      x = tf.constant(x_val, name="x")
      m = tf.constant(m_val, name="m")
      v = tf.constant(v_val, name="v")
      beta = tf.constant(beta_val, name="beta")
      gamma = tf.constant(gamma_val, name="gamma")
      epsilon = 0.001
      if version == 1:
        output = self._tfBatchNormV1(
            x, m, v, beta, gamma, epsilon, scale_after_normalization)
      elif version == 2:
        output = self._tfBatchNormV2(
            x, m, v, beta, gamma, epsilon, scale_after_normalization,
            shift_after_normalization)
      else:
        print("Invalid version", version)
        raise ValueError()
      all_params = [x, m, v, beta, gamma]
      all_shapes = [x_shape, param_shape, param_shape, param_shape, param_shape]
      err = tf.test.compute_gradient_error(
          all_params[param_index], all_shapes[param_index], output, x_shape)
      print("Batch normalization v%d %s gradient %s scale and %s shift err = " %
            (version, tag, "with" if scale_after_normalization else "without",
             "with" if shift_after_normalization else "without"),
            err)
      self.assertLess(err, err_tolerance)

  def _testBatchNormGradientInAllNeedConfigs(
      self, param_index, tag, err_tolerance=1e-11):
    for scale_after_normalization in [True, False]:
      for shift_after_normalization in [True, False]:
        # shift_after_normalization=False is not supported in version 1.
        for v in ([1, 2] if shift_after_normalization else [2]):
          self._testBatchNormGradient(
              param_index, tag, scale_after_normalization,
              shift_after_normalization, v, err_tolerance)

  def testBatchNormInputGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(0, "x")

  def testBatchNormMeanGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(1, "mean")

  def testBatchNormVarianceGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(2, "variance",
                                                err_tolerance=1e-03)

  def testBatchNormBetaGradient(self):
    # Since beta does not exist when shift_after_normalization=False, we only
    # test with shift_after_normalization=True; both scale settings are still
    # exercised.
    for scale_after_normalization in [True, False]:
      for v in [1, 2]:
        self._testBatchNormGradient(3, "beta", scale_after_normalization, True,
                                    v)

  def testBatchNormGammaGradient(self):
    # If scale_after_normalization is False, backprop for gamma in v1
    # will be 0. In version 2 of the API, if scale_after_normalization is
    # False, gamma is not used at all, and the gradient is None, which
    # displeases the gradient checker.
    for scale_after_normalization in [True, False]:
      self._testBatchNormGradient(4, "gamma", scale_after_normalization, True,
                                  1)
    for shift_after_normalization in [True, False]:
      self._testBatchNormGradient(4, "gamma", True, shift_after_normalization,
                                  2)
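
  # The raw gradient kernel used below returns a result whose named fields
  # (dx, dm, dv, db, dg) should also unpack positionally; the test checks
  # that both views agree and then compares them against tf.gradients() of
  # the composed-op implementation in _opsBatchNorm.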
  def testBatchNormGradImpl(self):
    x_shape = [7, 5, 4, 6]
    param_shape = [6]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    backprop_val = np.random.random_sample(x_shape).astype(np.float32)
    for use_gpu in [False, True]:
      with self.test_session(use_gpu=use_gpu) as sess:
        x = tf.constant(x_val, name="x")
        m = tf.constant(m_val, name="m")
        v = tf.constant(v_val, name="v")
        beta = tf.constant(beta_val, name="beta")
        gamma = tf.constant(gamma_val, name="gamma")
        backprop = tf.constant(backprop_val, name="backprop")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          # _batch_norm_with_global_normalization_grad is deprecated in v9
          tf.get_default_graph().graph_def_versions.producer = 8
          grad = gen_nn_ops._batch_norm_with_global_normalization_grad(
              x, m, v, gamma, backprop, epsilon, scale_after_normalization)
          dx, dm, dv, db, dg = grad
          self.assertEqual(grad.dx, dx)
          self.assertEqual(grad.dm, dm)
          self.assertEqual(grad.dv, dv)
          self.assertEqual(grad.db, db)
          self.assertEqual(grad.dg, dg)

          on = self._opsBatchNorm(
              x, m, v, beta, gamma, epsilon, scale_after_normalization, True)
          odx, odm, odv, odb, odg = tf.gradients(
              [on], [x, m, v, beta, gamma], [backprop])
          if scale_after_normalization:
            all_grads = sess.run([dx, dm, dv, db, dg, odx, odm, odv, odb, odg])
            to_check = ["dx", "dm", "dv", "db", "dg"]
          else:
            all_grads = sess.run([dx, dm, dv, db, odx, odm, odv, odb])
            to_check = ["dx", "dm", "dv", "db"]
          for i, _ in enumerate(to_check):
            self.assertAllClose(
                all_grads[i + len(to_check)], all_grads[i], atol=0.000001)

  def testBatchNormKeepDims(self):
    """Test for tf.nn.moments(..., keep_dims=True / False).

    Make sure that parameters with shape (1, 1, 1, depth) yield the same
    result as parameters with shape (depth)
    """
    x_shape = (3, 5, 4, 2)
    param_shape = (2)
    keep_dims_param_shape = (1, 1, 1, 2)
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        x = tf.constant(x_val, name="x")
        m = tf.constant(m_val, name="m")
        v = tf.constant(v_val, name="v")
        beta = tf.constant(beta_val, name="beta")
        gamma = tf.constant(gamma_val, name="gamma")
        keep_dims_m = tf.reshape(m, keep_dims_param_shape, name="keep_dims_m")
        keep_dims_v = tf.reshape(v, keep_dims_param_shape, name="keep_dims_v")
        keep_dims_beta = tf.reshape(
            beta, keep_dims_param_shape, name="keep_dims_beta")
        keep_dims_gamma = tf.reshape(
            gamma, keep_dims_param_shape, name="keep_dims_gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn = self._tfBatchNormV2(
                x, m, v, beta, gamma, epsilon, scale_after_normalization,
                shift_after_normalization)
            keep_dims_bn = self._tfBatchNormV2(
                x, keep_dims_m, keep_dims_v, keep_dims_beta,
                keep_dims_gamma, epsilon, scale_after_normalization,
                shift_after_normalization)
            tf_batch_norm, keep_dims_tf_batch_norm = sess.run(
                [bn, keep_dims_bn])
            self.assertEquals(x_shape, tf_batch_norm.shape)
            self.assertEquals(x_shape, keep_dims_tf_batch_norm.shape)
            self.assertAllClose(
                tf_batch_norm, keep_dims_tf_batch_norm, atol=0.000001)

  def _testBatchNormArbitraryShapes(self, x_shape, param_shape, atol=0.0001):
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        x = tf.constant(x_val, name="x")
        m = tf.constant(m_val, name="m")
        v = tf.constant(v_val, name="v")
        beta = tf.constant(beta_val, name="beta")
        gamma = tf.constant(gamma_val, name="gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn = self._tfBatchNormV2(
                x, m, v, beta, gamma, epsilon, scale_after_normalization,
                shift_after_normalization)
            np_batch_norm = self._npBatchNorm(
                x_val, m_val, v_val, beta_val, gamma_val, epsilon,
                scale_after_normalization, shift_after_normalization)
            [tf_batch_norm] = sess.run([bn])
            self.assertEquals(x_shape, np_batch_norm.shape)
            self.assertEquals(x_shape, tf_batch_norm.shape)
            self.assertAllClose(np_batch_norm, tf_batch_norm, atol=atol)

  def testBatchNormArbitraryShapes(self):
    """Test for a variety of shapes and moments.

    Batch normalization is expected to work regardless of the position and
    dimensionality of the 'depth' axis/axes.
    """
    self._testBatchNormArbitraryShapes((3, 3), (1, 3))
    self._testBatchNormArbitraryShapes((3, 3), (3, 1))
    self._testBatchNormArbitraryShapes((3, 2, 4, 5), (1, 2, 1, 1))
    self._testBatchNormArbitraryShapes((2, 3, 2, 4, 5), (1, 1, 1, 4, 5),
                                       atol=0.005)


class SufficientStatisticsTest(tf.test.TestCase):
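
  # tf.nn.sufficient_statistics returns the building blocks for moments:
  # the element count over `axes`, the (optionally shift-centered) sum,
  # the sum of squares, and the shift that was used. _npSuffStats mirrors
  # that contract in NumPy so the op's outputs can be checked directly.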
  def _npSuffStats(self, x, axes, shift, keep_dims):
    axis = tuple(axes)
    if shift is not None:
      m_ss = np.sum(x - shift, axis=axis, keepdims=keep_dims)
      v_ss = np.sum((x - shift) * (x - shift), axis=axis, keepdims=keep_dims)
    else:
      m_ss = np.sum(x, axis=axis, keepdims=keep_dims)
      v_ss = np.sum(x * x, axis=axis, keepdims=keep_dims)
    count = 1.0
    for d in xrange(x.ndim):
      if d in set(axes):
        count *= x.shape[d]
    if not keep_dims:
      shift = np.squeeze(shift, axis=axis)
    return count, m_ss, v_ss, shift

  def _opSuffStats(self, x, axes, shift, keep_dims):
    return tf.nn.sufficient_statistics(x, axes, shift, keep_dims)

  def _testSuffStats(self, x_shape, axes, shift, keep_dims, has_shape):
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    np_c, np_m, np_v, np_s = self._npSuffStats(x_val, axes, shift, keep_dims)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        if has_shape:
          x = tf.constant(x_val, name="x")
          x.set_shape(x_shape)
          op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
          if shift:
            tf_c, tf_m, tf_v, tf_s = sess.run([op_c, op_m, op_v, op_s])
          else:
            tf_c, tf_m, tf_v = sess.run([op_c, op_m, op_v])
        else:
          x = tf.placeholder(dtype=tf.float32,
                             shape=[None] * len(x_shape),
                             name="x")
          op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
          if shift:
            tf_c, tf_m, tf_v, tf_s = sess.run(
                [op_c, op_m, op_v, op_s],
                feed_dict={x: x_val})
          else:
            tf_c, tf_m, tf_v = sess.run(
                [op_c, op_m, op_v],
                feed_dict={x: x_val})
        self.assertAllClose(np_c, tf_c, atol=0.000001)
        self.assertAllClose(np_m, tf_m, atol=0.000001)
        self.assertAllClose(np_v, tf_v, atol=0.000001)
        if shift:
          self.assertAllClose(np_s, tf_s, atol=0.000001)

  def testSuffStats(self):
    for has_shape in [True, False]:
      for keep_dims in [True, False]:
        for shift in [None, 1.0]:
          self._testSuffStats([2, 3], [1], shift, keep_dims, has_shape)
          self._testSuffStats([2, 3], [0], shift, keep_dims, has_shape)
          self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape)


class NormalizeMomentsTest(tf.test.TestCase):
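
  # tf.nn.normalize_moments turns sufficient statistics back into moments.
  # The NumPy reference below computes
  #   mean     = mean_ss / counts            (plus shift, when one is given)
  #   variance = variance_ss / counts - (mean_ss / counts)**2
  # which is what the op is expected to reproduce.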
  def _npNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
    mean = mean_ss / counts
    variance = variance_ss / counts - mean * mean
    if shift is not None:
      mean += shift
    return mean, variance

  def _opNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
    return tf.nn.normalize_moments(counts, mean_ss, variance_ss, shift)

  def _testNormalizeMoments(self, shape, shift):
    counts = np.ones([1]).astype(np.float32)
    mean_ss = np.random.random_sample(shape).astype(np.float32)
    variance_ss = np.random.random_sample(shape).astype(np.float32)
    variance_ss *= variance_ss
    if shift:
      shift_v = np.random.random_sample(shape).astype(np.float32)
    else:
      shift_v = None
    npm, npv = self._npNormalizeMoments(counts, mean_ss, variance_ss, shift_v)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        tf_counts = tf.constant(counts, name="counts")
        tf_mean_ss = tf.constant(mean_ss, name="mean_ss")
        tf_variance_ss = tf.constant(variance_ss, name="variance_ss")
        if shift:
          tf_shift_v = tf.constant(shift_v, name="shift")
        else:
          tf_shift_v = None
        opm, opv = self._opNormalizeMoments(tf_counts, tf_mean_ss,
                                            tf_variance_ss, tf_shift_v)
        tfm, tfv = sess.run([opm, opv])
        self.assertAllClose(npm, tfm, atol=0.000001)
        self.assertAllClose(npv, tfv, atol=0.000001)

  def testNormalizeMoments(self):
    for shift in [None, 4.0]:
      self._testNormalizeMoments([3], shift)
      self._testNormalizeMoments([2, 3], shift)


class MomentsTest(tf.test.TestCase):
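
  # The expected values in these tests use the naive one-pass identity
  #   var(x) = E[x**2] - E[x]**2
  # which is prone to catastrophic cancellation; RunMomentTest therefore
  # casts the NumPy reference to float128 before computing it.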
  def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
    # Method to compute moments of `x` wrt `axes`.
    #
    # This is exposed so WeightedMomentsTest can inherit the tests and
    # assertions from MomentsTest; the extra_out_grads argument allows
    # its inherited gradient tests to assert gradients against the
    # weights as well as the input values.

    return tf.nn.moments(x, axes, keep_dims=keep_dims)

  def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
    with self.test_session():
      # shape = [batch, width, height, depth]
      assert len(shape) == 4

      x_numpy = np.random.normal(size=shape).astype(np.float32)
      x = tf.placeholder(dtype, shape=[None] * len(shape))

      mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)

      num_elements = np.prod([shape[i] for i in axes])

      ax = tuple(axes)
      expected_mean = np.sum(
          x_numpy, axis=ax, keepdims=keep_dims) / num_elements
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = np.sum(
          np.multiply(x_numpy, x_numpy),
          axis=ax,
          keepdims=keep_dims) / num_elements
      expected_variance = expected_x_squared - expected_mean_squared

      # Check that the moments are correct.
      self.assertAllCloseAccordingToType(expected_mean,
                                         mean.eval(feed_dict={x: x_numpy}))
      self.assertAllCloseAccordingToType(expected_variance,
                                         var.eval(feed_dict={x: x_numpy}))

  def RunMomentTest(self, shape, axes, keep_dims, dtype):
    with self.test_session():
      # shape = [batch, width, height, depth]
      assert len(shape) == 4

      x_numpy = np.random.normal(size=shape).astype(np.float32)
      x = tf.cast(tf.constant(x_numpy), dtype=dtype)

      # Compute the expected values at high precision since the method
      # is prone to catastrophic cancellation:
      x_numpy = x_numpy.astype(np.float128)

      mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)

      num_elements = np.prod([shape[i] for i in axes])

      ax = tuple(axes)
      expected_mean = np.sum(
          x_numpy, axis=ax, keepdims=keep_dims) / num_elements
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = np.sum(
          np.multiply(x_numpy, x_numpy),
          axis=ax,
          keepdims=keep_dims) / num_elements
      expected_variance = expected_x_squared - expected_mean_squared

      # Check that the moments are correct.
      self.assertAllCloseAccordingToType(expected_mean, mean.eval())
      self.assertAllCloseAccordingToType(expected_variance, var.eval())

  def testBasic(self):
    for keep_dims in [False, True]:
      for dtype in [tf.float32, tf.float16]:
        self.RunMomentTest(shape=[2, 3, 5, 4],
                           axes=[0],
                           keep_dims=keep_dims,
                           dtype=dtype)
        self.RunMomentTestWithDynamicShape(shape=[2, 3, 5, 4],
                                           axes=[0],
                                           keep_dims=keep_dims,
                                           dtype=dtype)

  def testGlobalNormalization(self):
    for keep_dims in [False, True]:
      for dtype in [tf.float32, tf.float16]:
        self.RunMomentTest(shape=[2, 3, 5, 4],
                           axes=[0, 1, 2],
                           keep_dims=keep_dims,
                           dtype=dtype)
        self.RunMomentTestWithDynamicShape(shape=[2, 3, 5, 4],
                                           axes=[0, 1, 2],
                                           keep_dims=keep_dims,
                                           dtype=dtype)

  def testAxes(self):
    for keep_dims in [False, True]:
      for dtype in [tf.float32, tf.float16]:
        self.RunMomentTest(shape=[2, 3, 5, 4],
                           axes=[1, 2, 3],
                           keep_dims=keep_dims,
                           dtype=dtype)
        self.RunMomentTestWithDynamicShape(shape=[2, 3, 5, 4],
                                           axes=[1, 2, 3],
                                           keep_dims=keep_dims,
                                           dtype=dtype)

  def _testGlobalGradient(self, from_y="mean"):
    with self.test_session():
      x_shape = [3, 5, 4, 2]
      x_val = np.random.random_sample(x_shape).astype(np.float64)
      x = tf.constant(x_val)
      x.set_shape(x_shape)

      axes = [0, 1, 2]
      y_shape = [2]  # Depth of x

      inputs_to_compute_gradients_for = [x]

      out_mean, out_var = self._unweighted_moments(
          x, axes, extra_out_grads=inputs_to_compute_gradients_for)
      if from_y == "mean":
        y = out_mean
      elif from_y == "var":
        y = out_var

      for (i, v) in enumerate(inputs_to_compute_gradients_for):
        err = tf.test.compute_gradient_error(v, v.get_shape().as_list(),
                                             y, y_shape)
        print("Moments %s gradient err vs input %d = %g" % (from_y, i, err))
        self.assertLess(err, 1e-11)

  def testMeanGlobalGradient(self):
    self._testGlobalGradient(from_y="mean")

  def testVarGlobalGradient(self):
    self._testGlobalGradient(from_y="var")


class WeightedMomentsTest(MomentsTest):
  """Tests for nn.weighted_moments.

  Note that this test inherits from MomentsTest, inheriting all its
  test methods!

  It modifies MomentsTest in two ways:

  a) By overriding _unweighted_moments, all the codepaths in
     MomentsTest are executed, but with calls to tf.nn.moments()
     replaced by calls to tf.nn.weighted_moments() with a constant
     weight of 1.

  b) By overriding RunMomentTest and RunMomentTestWithDynamicShape,
     this test adds multiple additional calls to
     RunWeightedMomentTest() to exercise correctness with
     non-constant weights and varying broadcasting situations. (It
     also continues to call the inherited MomentsTest versions as
     well.)
  """
  def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
    weights = tf.constant(1, dtype=x.dtype)
    if extra_out_grads is not None:
      # We want to assert gradients WRT weights as well as X!
      extra_out_grads.append(weights)
    return tf.nn.weighted_moments(
        x, axes, weights, keep_dims=keep_dims)

  def RunMomentTest(self, shape, axes, keep_dims, dtype, dynshapes=False):
    if not dynshapes:
      super(WeightedMomentsTest, self).RunMomentTest(
          shape, axes, keep_dims, dtype)
    else:
      super(WeightedMomentsTest, self).RunMomentTestWithDynamicShape(
          shape, axes, keep_dims, dtype)

    # 1:1 weights and inputs
    self.RunWeightedMomentTest(shape, shape, axes, keep_dims, dtype)

    # Various broadcasting combinations
    for idx in range(len(shape)):
      # try broadcasting weights in all positions
      weight_shape = [1] * len(shape)
      weight_shape[idx] = shape[idx]

      self.RunWeightedMomentTest(shape, weight_shape, axes, keep_dims, dtype)

      # Also try broadcasting with a suffix of length n
      weight_shape = shape[-(idx + 1):]
      self.RunWeightedMomentTest(
          shape, weight_shape, axes, keep_dims, dtype, dynshapes=dynshapes)

  def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
    self.RunMomentTest(shape, axes, keep_dims, dtype, dynshapes=True)

  def RunWeightedMomentTest(
      self, shape, weights_shape, axes, keep_dims, dtype, dynshapes=False):
    with self.test_session() as s:
      x_numpy = np.random.normal(size=shape).astype(np.float32)
      weights_numpy = np.absolute(  # weights must be positive
          np.random.normal(size=weights_shape, loc=1.0).astype(np.float32))

      # Expand the numpy version to higher precision
      x_numpy = x_numpy.astype(np.float128)
      weights_numpy = weights_numpy.astype(np.float128)

      x_shape = [None] * len(shape) if dynshapes else shape
      weights_shape = (
          [None] * len(weights_shape) if dynshapes else weights_shape)

      x = tf.placeholder(dtype, shape=x_shape)
      weights = tf.placeholder(dtype, shape=weights_shape)

      mean, var = tf.nn.weighted_moments(x, axes, weights, keep_dims=keep_dims)

      ax = tuple(axes)

      def _np_weighted_sum(v):
        return np.sum(weights_numpy * v, axis=ax, keepdims=keep_dims)

      weight_sum = _np_weighted_sum(np.ones_like(x_numpy))
      expected_mean = _np_weighted_sum(x_numpy) / weight_sum
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = (
          _np_weighted_sum(np.multiply(x_numpy, x_numpy)) / weight_sum)
      expected_variance = expected_x_squared - expected_mean_squared

      mean_v, var_v = s.run([mean, var],
                            feed_dict={x: x_numpy, weights: weights_numpy})

      self.assertAllCloseAccordingToType(expected_mean, mean_v)
      self.assertAllCloseAccordingToType(expected_variance, var_v)


if __name__ == "__main__":
  tf.test.main()