Refactor out moment-testing and make more numpy friendly.
PiperOrigin-RevId: 230804842
This commit is contained in:
parent
2e17f8b3b7
commit
7c929bf6ef
@ -14,6 +14,14 @@ load("//tensorflow:tensorflow.bzl", "sycl_py_test")
|
||||
# Please avoid the py_tests and cuda_py_tests (plural) while we
|
||||
# fix the shared/overbroad dependencies.
|
||||
|
||||
py_library(
|
||||
name = "util",
|
||||
srcs = ["util.py"],
|
||||
deps = [
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "random_shuffle_queue_test",
|
||||
size = "small",
|
||||
@ -115,6 +123,7 @@ cuda_py_test(
|
||||
size = "medium",
|
||||
srcs = ["random_gamma_test.py"],
|
||||
additional_deps = [
|
||||
":util",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
@ -151,6 +160,7 @@ cuda_py_test(
|
||||
size = "medium",
|
||||
srcs = ["random_poisson_test.py"],
|
||||
additional_deps = [
|
||||
":util",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
|
@ -18,8 +18,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
@ -27,6 +25,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.kernel_tests.random import util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
@ -69,16 +68,6 @@ class RandomGammaTest(test.TestCase):
|
||||
tf_logging.warn("Cannot test moments: %s" % e)
|
||||
return
|
||||
|
||||
# Check the given array of samples matches the given theoretical moment
|
||||
# function at different orders. The test is considered passing if the
|
||||
# z-tests of all statistical moments are all below z_limit.
|
||||
# Parameters:
|
||||
# max_moments: the largest moments of the distribution to be tested
|
||||
# stride: the distance between samples to check for statistical properties
|
||||
# 0 means the n-th moment of each sample
|
||||
# any other strides tests for spatial correlation between samples;
|
||||
# z_limit: the maximum z-test we would consider the test to pass;
|
||||
|
||||
# The moments test is a z-value test. This is the largest z-value
|
||||
# we want to tolerate. Since the z-test approximates a unit normal
|
||||
# distribution, it should almost definitely never exceed 6.
|
||||
@ -94,46 +83,13 @@ class RandomGammaTest(test.TestCase):
|
||||
max_moment = min(6, scale // 2)
|
||||
sampler = self._Sampler(
|
||||
20000, alpha, 1 / scale, dt, use_gpu=False, seed=12345)
|
||||
moments = [0] * (max_moment + 1)
|
||||
moments_sample_count = [0] * (max_moment + 1)
|
||||
x = np.array(sampler().flat) # sampler does 10x samples
|
||||
for k in range(len(x)):
|
||||
moment = 1.
|
||||
for i in range(max_moment + 1):
|
||||
index = k + i * stride
|
||||
if index >= len(x):
|
||||
break
|
||||
moments[i] += moment
|
||||
moments_sample_count[i] += 1
|
||||
moment *= x[index]
|
||||
for i in range(max_moment + 1):
|
||||
moments[i] /= moments_sample_count[i]
|
||||
for i in range(1, max_moment + 1):
|
||||
g = stats.gamma(alpha, scale=scale)
|
||||
if stride == 0:
|
||||
moments_i_mean = g.moment(i)
|
||||
moments_i_squared = g.moment(2 * i)
|
||||
else:
|
||||
moments_i_mean = pow(g.moment(1), i)
|
||||
moments_i_squared = pow(g.moment(2), i)
|
||||
# Calculate moment variance safely:
|
||||
# This is just
|
||||
# (moments_i_squared - moments_i_mean**2) / moments_sample_count[i]
|
||||
normalized_moments_i_var = (
|
||||
moments_i_mean / moments_sample_count[i] *
|
||||
(moments_i_squared / moments_i_mean - moments_i_mean))
|
||||
# Assume every operation has a small numerical error.
|
||||
# It takes i multiplications to calculate one i-th moment.
|
||||
error_per_moment = i * np.finfo(dt.as_numpy_dtype).eps
|
||||
total_variance = (normalized_moments_i_var + error_per_moment)
|
||||
tiny = np.finfo(dt.as_numpy_dtype).tiny
|
||||
self.assertGreaterEqual(total_variance, 0)
|
||||
if total_variance < tiny:
|
||||
total_variance = tiny
|
||||
# z_test is approximately a unit normal distribution.
|
||||
z_test = abs(
|
||||
(moments[i] - moments_i_mean) / math.sqrt(total_variance))
|
||||
self.assertLess(z_test, z_limit)
|
||||
z_scores = util.test_moment_matching(
|
||||
sampler(),
|
||||
max_moment,
|
||||
stats.gamma(alpha, scale=scale),
|
||||
stride=stride,
|
||||
)
|
||||
self.assertAllLess(z_scores, z_limit)
|
||||
|
||||
def _testZeroDensity(self, alpha):
|
||||
"""Zero isn't in the support of the gamma distribution.
|
||||
|
@ -24,6 +24,7 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.kernel_tests.random import util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -49,14 +50,13 @@ class RandomPoissonTest(test.TestCase):
|
||||
|
||||
return func
|
||||
|
||||
# TODO(srvasude): Factor this out along with the corresponding moment testing
|
||||
# method in random_gamma_test into a single library.
|
||||
def testMoments(self):
|
||||
try:
|
||||
from scipy import stats # pylint: disable=g-import-not-at-top
|
||||
except ImportError as e:
|
||||
tf_logging.warn("Cannot test moments: %s", e)
|
||||
return
|
||||
|
||||
# The moments test is a z-value test. This is the largest z-value
|
||||
# we want to tolerate. Since the z-test approximates a unit normal
|
||||
# distribution, it should almost definitely never exceed 6.
|
||||
@ -67,41 +67,13 @@ class RandomPoissonTest(test.TestCase):
|
||||
for lam in (3., 20):
|
||||
max_moment = 5
|
||||
sampler = self._Sampler(10000, lam, dt, use_gpu=False, seed=12345)
|
||||
moments = [0] * (max_moment + 1)
|
||||
moments_sample_count = [0] * (max_moment + 1)
|
||||
x = np.array(sampler().flat) # sampler does 10x samples
|
||||
for k in range(len(x)):
|
||||
moment = 1.
|
||||
for i in range(max_moment + 1):
|
||||
index = k + i * stride
|
||||
if index >= len(x):
|
||||
break
|
||||
moments[i] += moment
|
||||
moments_sample_count[i] += 1
|
||||
moment *= x[index]
|
||||
for i in range(max_moment + 1):
|
||||
moments[i] /= moments_sample_count[i]
|
||||
for i in range(1, max_moment + 1):
|
||||
g = stats.poisson(lam)
|
||||
if stride == 0:
|
||||
moments_i_mean = g.moment(i)
|
||||
moments_i_squared = g.moment(2 * i)
|
||||
else:
|
||||
moments_i_mean = pow(g.moment(1), i)
|
||||
moments_i_squared = pow(g.moment(2), i)
|
||||
moments_i_var = (
|
||||
moments_i_squared - moments_i_mean * moments_i_mean)
|
||||
# Assume every operation has a small numerical error.
|
||||
# It takes i multiplications to calculate one i-th moment.
|
||||
error_per_moment = i * 1e-6
|
||||
total_variance = (
|
||||
moments_i_var / moments_sample_count[i] + error_per_moment)
|
||||
if not total_variance:
|
||||
total_variance = 1e-10
|
||||
# z_test is approximately a unit normal distribution.
|
||||
z_test = abs(
|
||||
(moments[i] - moments_i_mean) / np.sqrt(total_variance))
|
||||
self.assertLess(z_test, z_limit)
|
||||
z_scores = util.test_moment_matching(
|
||||
sampler(),
|
||||
max_moment,
|
||||
stats.poisson(lam),
|
||||
stride=stride,
|
||||
)
|
||||
self.assertAllLess(z_scores, z_limit)
|
||||
|
||||
# Checks that the CPU and GPU implementation returns the same results,
|
||||
# given the same random seed
|
||||
|
72
tensorflow/python/kernel_tests/random/util.py
Normal file
72
tensorflow/python/kernel_tests/random/util.py
Normal file
@ -0,0 +1,72 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utilities for testing random variables."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def test_moment_matching(
    samples,
    number_moments,
    dist,
    stride=0):
  """Return z-test scores for sample moments to match analytic moments.

  Given `samples`, check that the first `number_moments` sample moments match
  the analytic moments of `dist` by computing one z-score per moment.

  Args:
    samples: Samples from target distribution. Any array-like is accepted and
      is flattened before use. Non-floating samples (e.g. integer Poisson
      draws) are promoted to `float64`.
    number_moments: Python `int` describing how many sample moments to check.
    dist: SciPy distribution object that provides analytic moments through
      `dist.moment(k)`.
    stride: Distance between samples to check for statistical properties.
      The `i`-th moment is estimated from every `(i - 1) * stride + 1`-th
      sample, so a stride of 0 means to use all samples, while other strides
      test for spatial correlation.

  Returns:
    Python list of `number_moments` z-test scores; entry `i - 1` belongs to
    the `i`-th moment.

  Raises:
    AssertionError: If the estimated variance of a sample moment is not
      strictly positive.
  """
  x = np.asarray(samples)
  if not np.issubdtype(x.dtype, np.floating):
    # `np.finfo` only accepts inexact dtypes, and integer powers could
    # overflow, so do the arithmetic on non-floating samples in float64.
    x = x.astype(np.float64)
  x = x.ravel()

  float_info = np.finfo(x.dtype)
  eps = float_info.eps
  tiny = float_info.tiny

  z_test_scores = []
  for i in range(1, number_moments + 1):
    strided_range = x[::(i - 1) * stride + 1]
    sample_moment = np.mean(strided_range ** i)
    expected_moment = dist.moment(i)
    # Variance of the mean of `len(strided_range)` draws of X**i.
    variance_sample_moment = (
        (dist.moment(2 * i) - expected_moment ** 2) / len(strided_range))

    # Assume every operation has a small numerical error.
    # It takes i multiplications to calculate one i-th moment.
    total_variance = variance_sample_moment + i * eps
    assert np.all(total_variance > 0), (
        "Non-positive variance estimate for moment %d" % i)
    # Keep the denominator of the z-score away from zero.
    if total_variance < tiny:
      total_variance = tiny

    # z_test is approximately a unit normal distribution.
    z_test_scores.append(
        abs((sample_moment - expected_moment) / np.sqrt(total_variance)))
  return z_test_scores


# This is a helper, not a test case: stop name-based collectors (pytest, nose)
# from picking it up because of its `test_` prefix when it is star-imported.
test_moment_matching.__test__ = False
|
||||
|
@ -82,6 +82,7 @@ COMMON_PIP_DEPS = [
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/debug:debug_pip",
|
||||
"//tensorflow/python/eager:eager_pip",
|
||||
"//tensorflow/python/kernel_tests/random:util",
|
||||
"//tensorflow/python/kernel_tests/signal:test_util",
|
||||
"//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files",
|
||||
"//tensorflow/python/ops/ragged:ragged_test_util",
|
||||
|
Loading…
Reference in New Issue
Block a user