Fixed compiler/tests:stateful_random_ops_test for f64, made the test more parallel and faster, and removed @run_v2_only
so the test can be picked up by TAP.
PiperOrigin-RevId: 256223403
This commit is contained in:
parent
840f25bd46
commit
096f7e3906
@ -942,8 +942,9 @@ tf_xla_py_test(
|
|||||||
|
|
||||||
tf_xla_py_test(
|
tf_xla_py_test(
|
||||||
name = "stateful_random_ops_test",
|
name = "stateful_random_ops_test",
|
||||||
size = "small",
|
size = "medium",
|
||||||
srcs = ["stateful_random_ops_test.py"],
|
srcs = ["stateful_random_ops_test.py"],
|
||||||
|
shard_count = 10,
|
||||||
tags = ["optonly"],
|
tags = ["optonly"],
|
||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
@ -957,7 +958,7 @@ tf_xla_py_test(
|
|||||||
|
|
||||||
tf_xla_py_test(
|
tf_xla_py_test(
|
||||||
name = "stateless_random_ops_test",
|
name = "stateless_random_ops_test",
|
||||||
size = "small",
|
size = "medium",
|
||||||
srcs = ["stateless_random_ops_test.py"],
|
srcs = ["stateless_random_ops_test.py"],
|
||||||
tags = ["optonly"],
|
tags = ["optonly"],
|
||||||
deps = [
|
deps = [
|
||||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -27,7 +29,6 @@ from tensorflow.python.eager import def_function
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
|
||||||
from tensorflow.python.kernel_tests.random import util as \
|
from tensorflow.python.kernel_tests.random import util as \
|
||||||
random_test_util
|
random_test_util
|
||||||
from tensorflow.python.ops import gen_stateful_random_ops
|
from tensorflow.python.ops import gen_stateful_random_ops
|
||||||
@ -37,33 +38,33 @@ from tensorflow.python.ops import variables
|
|||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
def xla_device_name():
|
def xla_device():
|
||||||
devices = device_lib.list_local_devices()
|
devices = device_lib.list_local_devices()
|
||||||
def find_type(device_type):
|
def find_type(device_type):
|
||||||
for d in devices:
|
for d in devices:
|
||||||
if d.device_type == device_type:
|
if d.device_type == device_type:
|
||||||
return d.name
|
return d
|
||||||
return None
|
return None
|
||||||
name = find_type("TPU") or find_type("XLA_GPU") or find_type("XLA_CPU")
|
d = find_type("TPU") or find_type("XLA_GPU") or find_type("XLA_CPU")
|
||||||
if name is None:
|
if d is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Can't find any XLA device. Available devices:\n%s" % devices)
|
"Can't find any XLA device. Available devices:\n%s" % devices)
|
||||||
return str(name)
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def xla_device_name():
|
||||||
|
return str(xla_device().name)
|
||||||
|
|
||||||
|
|
||||||
ALGS = [random.RNG_ALG_PHILOX, random.RNG_ALG_THREEFRY]
|
ALGS = [random.RNG_ALG_PHILOX, random.RNG_ALG_THREEFRY]
|
||||||
INTS = [dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64]
|
INTS = [dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64]
|
||||||
|
FLOATS = [dtypes.bfloat16, dtypes.float32, dtypes.float64]
|
||||||
|
|
||||||
|
|
||||||
# TODO(wangpeng): use parametrized tests to test both ThreeFry and Philox
|
|
||||||
class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||||
"""Test cases for stateful random-number generator operators."""
|
"""Test cases for stateful random-number generator operators."""
|
||||||
|
|
||||||
_ints = INTS
|
|
||||||
_floats = [dtypes.bfloat16, dtypes.float32, dtypes.float64]
|
|
||||||
|
|
||||||
@parameterized.parameters(ALGS)
|
@parameterized.parameters(ALGS)
|
||||||
@test_util.run_v2_only
|
|
||||||
def testSimple(self, alg):
|
def testSimple(self, alg):
|
||||||
"""A simple test."""
|
"""A simple test."""
|
||||||
with ops.device(xla_device_name()):
|
with ops.device(xla_device_name()):
|
||||||
@ -73,7 +74,6 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
gen.uniform_full_int(shape=(3,))
|
gen.uniform_full_int(shape=(3,))
|
||||||
|
|
||||||
@parameterized.parameters(ALGS)
|
@parameterized.parameters(ALGS)
|
||||||
@test_util.run_v2_only
|
|
||||||
def testDefun(self, alg):
|
def testDefun(self, alg):
|
||||||
"""Test for defun."""
|
"""Test for defun."""
|
||||||
with ops.device(xla_device_name()):
|
with ops.device(xla_device_name()):
|
||||||
@ -106,7 +106,6 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
got = g.uniform_full_int(shape=(ctr_len // 2,), dtype=dtypes.uint64)
|
got = g.uniform_full_int(shape=(ctr_len // 2,), dtype=dtypes.uint64)
|
||||||
self.assertAllEqual(uint32s_to_uint64s(expect), got)
|
self.assertAllEqual(uint32s_to_uint64s(expect), got)
|
||||||
|
|
||||||
@test_util.run_v2_only
|
|
||||||
def testThreefry2x32(self):
|
def testThreefry2x32(self):
|
||||||
"""Tests ThreeFry2x32 conforms to known results.
|
"""Tests ThreeFry2x32 conforms to known results.
|
||||||
"""
|
"""
|
||||||
@ -130,7 +129,6 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
[0x243f6a88, 0x85a308d3], [0x13198a2e, 0x03707344],
|
[0x243f6a88, 0x85a308d3], [0x13198a2e, 0x03707344],
|
||||||
[0xc4923a9c, 0x483df7a0])
|
[0xc4923a9c, 0x483df7a0])
|
||||||
|
|
||||||
@test_util.run_v2_only
|
|
||||||
def testPhilox4x32(self):
|
def testPhilox4x32(self):
|
||||||
"""Tests Philox4x32 conforms to known results.
|
"""Tests Philox4x32 conforms to known results.
|
||||||
"""
|
"""
|
||||||
@ -155,7 +153,6 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
[0xa4093822, 0x299f31d0],
|
[0xa4093822, 0x299f31d0],
|
||||||
[0xd16cfe09, 0x94fdcceb, 0x5001e420, 0x24126ea1])
|
[0xd16cfe09, 0x94fdcceb, 0x5001e420, 0x24126ea1])
|
||||||
|
|
||||||
@test_util.run_v2_only
|
|
||||||
def testNewStateThreeFry(self):
|
def testNewStateThreeFry(self):
|
||||||
"""Tests that the new state is correct (for ThreeFry).
|
"""Tests that the new state is correct (for ThreeFry).
|
||||||
"""
|
"""
|
||||||
@ -171,7 +168,6 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint64)
|
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint64)
|
||||||
self.assertAllEqual([counter+size, key], gen.state.read_value())
|
self.assertAllEqual([counter+size, key], gen.state.read_value())
|
||||||
|
|
||||||
@test_util.run_v2_only
|
|
||||||
def testNewStatePhilox(self):
|
def testNewStatePhilox(self):
|
||||||
"""Tests that the new state is correct (for Philox).
|
"""Tests that the new state is correct (for Philox).
|
||||||
"""
|
"""
|
||||||
@ -204,7 +200,6 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
gen.state.read_value())
|
gen.state.read_value())
|
||||||
|
|
||||||
@parameterized.parameters(INTS)
|
@parameterized.parameters(INTS)
|
||||||
@test_util.run_v2_only
|
|
||||||
def testXLAEqualsCPU(self, dtype):
|
def testXLAEqualsCPU(self, dtype):
|
||||||
"""Tests that XLA and CPU kernels generate the same integers."""
|
"""Tests that XLA and CPU kernels generate the same integers."""
|
||||||
seed = 1234
|
seed = 1234
|
||||||
@ -225,63 +220,60 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
y = rng(dtype).numpy()
|
y = rng(dtype).numpy()
|
||||||
self.assertFalse(np.array_equal(x, y))
|
self.assertFalse(np.array_equal(x, y))
|
||||||
|
|
||||||
@parameterized.parameters(ALGS)
|
def check_dtype(self, dtype):
|
||||||
@test_util.run_v2_only
|
device = xla_device()
|
||||||
def testUniformIsNotConstant(self, alg):
|
if device.device_type == "TPU" and dtype == dtypes.float64:
|
||||||
|
self.skipTest("TPU doesn't support float64.")
|
||||||
|
|
||||||
|
@parameterized.parameters(list(itertools.product(ALGS, INTS + FLOATS)))
|
||||||
|
def testUniformIsNotConstant(self, alg, dtype):
|
||||||
|
self.check_dtype(dtype)
|
||||||
with ops.device(xla_device_name()):
|
with ops.device(xla_device_name()):
|
||||||
gen = random.Generator.from_seed(seed=1234, alg=alg)
|
gen = random.Generator.from_seed(seed=1234, alg=alg)
|
||||||
def rng(dtype):
|
def rng(dtype):
|
||||||
maxval = dtype.max
|
maxval = dtype.max
|
||||||
# Workaround for b/125364959
|
|
||||||
if dtype == dtypes.uint64:
|
|
||||||
maxval = 10000000
|
|
||||||
return gen.uniform(shape=[2], dtype=dtype, maxval=maxval)
|
return gen.uniform(shape=[2], dtype=dtype, maxval=maxval)
|
||||||
|
|
||||||
for dtype in self._ints + self._floats:
|
|
||||||
self._testRngIsNotConstant(rng, dtype)
|
self._testRngIsNotConstant(rng, dtype)
|
||||||
|
|
||||||
@parameterized.parameters(ALGS)
|
@parameterized.parameters(list(itertools.product(ALGS, FLOATS)))
|
||||||
@test_util.run_v2_only
|
def testNormalIsNotConstant(self, alg, dtype):
|
||||||
def testNormalIsNotConstant(self, alg):
|
self.check_dtype(dtype)
|
||||||
with ops.device(xla_device_name()):
|
with ops.device(xla_device_name()):
|
||||||
gen = random.Generator.from_seed(seed=1234, alg=alg)
|
gen = random.Generator.from_seed(seed=1234, alg=alg)
|
||||||
def rng(dtype):
|
def rng(dtype):
|
||||||
return gen.normal(shape=[2], dtype=dtype)
|
return gen.normal(shape=[2], dtype=dtype)
|
||||||
|
|
||||||
for dtype in self._floats:
|
|
||||||
self._testRngIsNotConstant(rng, dtype)
|
self._testRngIsNotConstant(rng, dtype)
|
||||||
|
|
||||||
@parameterized.parameters(ALGS)
|
@parameterized.parameters(list(itertools.product(ALGS, INTS + FLOATS)))
|
||||||
@test_util.run_v2_only
|
def testUniformIsInRange(self, alg, dtype):
|
||||||
def testUniformIsInRange(self, alg):
|
self.check_dtype(dtype)
|
||||||
minval = 2
|
minval = 2
|
||||||
maxval = 33
|
maxval = 33
|
||||||
size = 1000
|
size = 1000
|
||||||
with ops.device(xla_device_name()):
|
with ops.device(xla_device_name()):
|
||||||
for dtype in self._ints + self._floats:
|
|
||||||
gen = random.Generator.from_seed(seed=1234, alg=alg)
|
gen = random.Generator.from_seed(seed=1234, alg=alg)
|
||||||
x = gen.uniform(
|
x = gen.uniform(
|
||||||
shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy()
|
shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy()
|
||||||
self.assertTrue(np.all(x >= minval))
|
self.assertTrue(np.all(x >= minval))
|
||||||
self.assertTrue(np.all(x <= maxval))
|
self.assertTrue(np.all(x <= maxval))
|
||||||
|
|
||||||
@parameterized.parameters(ALGS)
|
@parameterized.parameters(list(itertools.product(ALGS, FLOATS)))
|
||||||
@test_util.run_v2_only
|
def testNormalIsFinite(self, alg, dtype):
|
||||||
def testNormalIsFinite(self, alg):
|
self.check_dtype(dtype)
|
||||||
with ops.device(xla_device_name()):
|
with ops.device(xla_device_name()):
|
||||||
gen = random.Generator.from_seed(seed=1234, alg=alg)
|
gen = random.Generator.from_seed(seed=1234, alg=alg)
|
||||||
for dtype in self._floats:
|
|
||||||
x = gen.normal(shape=[10000], dtype=dtype).numpy()
|
x = gen.normal(shape=[10000], dtype=dtype).numpy()
|
||||||
self.assertTrue(np.all(np.isfinite(x)))
|
self.assertTrue(np.all(np.isfinite(x)))
|
||||||
|
|
||||||
@parameterized.parameters(ALGS)
|
@parameterized.parameters(list(itertools.product(ALGS, INTS + FLOATS)))
|
||||||
@test_util.run_v2_only
|
def testDistributionOfUniform(self, alg, dtype):
|
||||||
def testDistributionOfUniform(self, alg):
|
|
||||||
"""Use Pearson's Chi-squared test to test for uniformity."""
|
"""Use Pearson's Chi-squared test to test for uniformity."""
|
||||||
|
self.check_dtype(dtype)
|
||||||
with ops.device(xla_device_name()):
|
with ops.device(xla_device_name()):
|
||||||
n = 1000
|
n = 1000
|
||||||
seed = 12
|
seed = 12
|
||||||
for dtype in self._ints + self._floats:
|
|
||||||
gen = random.Generator.from_seed(seed=seed, alg=alg)
|
gen = random.Generator.from_seed(seed=seed, alg=alg)
|
||||||
maxval = 1
|
maxval = 1
|
||||||
if dtype.is_integer:
|
if dtype.is_integer:
|
||||||
@ -297,13 +289,12 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
val = random_test_util.chi_squared(x, 10)
|
val = random_test_util.chi_squared(x, 10)
|
||||||
self.assertLess(val, 16.92)
|
self.assertLess(val, 16.92)
|
||||||
|
|
||||||
@parameterized.parameters(ALGS)
|
@parameterized.parameters(list(itertools.product(ALGS, FLOATS)))
|
||||||
@test_util.run_v2_only
|
def testDistributionOfNormal(self, alg, dtype):
|
||||||
def testDistributionOfNormal(self, alg):
|
|
||||||
"""Use Anderson-Darling test to test distribution appears normal."""
|
"""Use Anderson-Darling test to test distribution appears normal."""
|
||||||
|
self.check_dtype(dtype)
|
||||||
with ops.device(xla_device_name()):
|
with ops.device(xla_device_name()):
|
||||||
n = 1000
|
n = 1000
|
||||||
for dtype in self._floats:
|
|
||||||
gen = random.Generator.from_seed(seed=1234, alg=alg)
|
gen = random.Generator.from_seed(seed=1234, alg=alg)
|
||||||
x = gen.normal(shape=[n], dtype=dtype).numpy()
|
x = gen.normal(shape=[n], dtype=dtype).numpy()
|
||||||
# The constant 2.492 is the 5% critical value for the Anderson-Darling
|
# The constant 2.492 is the 5% critical value for the Anderson-Darling
|
||||||
@ -312,18 +303,18 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
self.assertLess(
|
self.assertLess(
|
||||||
random_test_util.anderson_darling(x.astype(float)), 2.492)
|
random_test_util.anderson_darling(x.astype(float)), 2.492)
|
||||||
|
|
||||||
@parameterized.parameters(ALGS)
|
@parameterized.parameters(list(itertools.product(ALGS, FLOATS)))
|
||||||
@test_util.run_v2_only
|
def testTruncatedNormal(self, alg, dtype):
|
||||||
def testTruncatedNormal(self, alg):
|
self.check_dtype(dtype)
|
||||||
with ops.device(xla_device_name()):
|
with ops.device(xla_device_name()):
|
||||||
for dtype in self._floats:
|
|
||||||
gen = random.Generator.from_seed(seed=123, alg=alg)
|
gen = random.Generator.from_seed(seed=123, alg=alg)
|
||||||
n = 10000000
|
n = 100000
|
||||||
y = gen.truncated_normal(shape=[n], dtype=dtype).numpy()
|
y = gen.truncated_normal(shape=[n], dtype=dtype).numpy()
|
||||||
random_test_util.test_truncated_normal(
|
random_test_util.test_truncated_normal(
|
||||||
self.assertEqual, self.assertAllClose, dtype, n, y)
|
self.assertEqual, self.assertAllClose, n, y,
|
||||||
|
mean_atol=2e-3, median_atol=4e-3,
|
||||||
|
variance_rtol=1e-2 if dtype == dtypes.bfloat16 else 5e-3)
|
||||||
|
|
||||||
@test_util.run_v2_only
|
|
||||||
def testErrors(self):
|
def testErrors(self):
|
||||||
"""Tests that proper errors are raised.
|
"""Tests that proper errors are raised.
|
||||||
"""
|
"""
|
||||||
@ -371,4 +362,5 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
ops.enable_eager_execution()
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -128,7 +128,8 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
|
|||||||
shape=[n], seed=seed_t, dtype=dtype)
|
shape=[n], seed=seed_t, dtype=dtype)
|
||||||
y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
|
y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
|
||||||
random_test_util.test_truncated_normal(
|
random_test_util.test_truncated_normal(
|
||||||
self.assertEqual, self.assertAllClose, dtype, n, y)
|
self.assertEqual, self.assertAllClose, n, y,
|
||||||
|
variance_rtol=6e-3 if dtype == dtypes.bfloat16 else 1e-3)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -22,7 +22,6 @@ import math
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.framework import dtypes
|
|
||||||
from tensorflow.python.ops.distributions import special_math
|
from tensorflow.python.ops.distributions import special_math
|
||||||
|
|
||||||
|
|
||||||
@ -100,7 +99,8 @@ def anderson_darling(x):
|
|||||||
return -n - z / n
|
return -n - z / n
|
||||||
|
|
||||||
|
|
||||||
def test_truncated_normal(assert_equal, assert_all_close, dtype, n, y):
|
def test_truncated_normal(assert_equal, assert_all_close, n, y,
|
||||||
|
mean_atol=5e-4, median_atol=8e-4, variance_rtol=1e-3):
|
||||||
"""Tests truncated normal distribution's statistics."""
|
"""Tests truncated normal distribution's statistics."""
|
||||||
def _normal_cdf(x):
|
def _normal_cdf(x):
|
||||||
return .5 * math.erfc(-x / math.sqrt(2))
|
return .5 * math.erfc(-x / math.sqrt(2))
|
||||||
@ -129,12 +129,12 @@ def test_truncated_normal(assert_equal, assert_all_close, dtype, n, y):
|
|||||||
expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
|
expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
|
||||||
y = y.astype(float)
|
y = y.astype(float)
|
||||||
actual_mean = np.mean(y)
|
actual_mean = np.mean(y)
|
||||||
assert_all_close(actual_mean, expected_mean, atol=5e-4)
|
assert_all_close(actual_mean, expected_mean, atol=mean_atol)
|
||||||
|
|
||||||
expected_median = mu + probit(
|
expected_median = mu + probit(
|
||||||
(_normal_cdf(alpha) + _normal_cdf(beta)) / 2.) * sigma
|
(_normal_cdf(alpha) + _normal_cdf(beta)) / 2.) * sigma
|
||||||
actual_median = np.median(y)
|
actual_median = np.median(y)
|
||||||
assert_all_close(actual_median, expected_median, atol=8e-4)
|
assert_all_close(actual_median, expected_median, atol=median_atol)
|
||||||
|
|
||||||
expected_variance = sigma**2 * (1 + (
|
expected_variance = sigma**2 * (1 + (
|
||||||
(alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - (
|
(alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - (
|
||||||
@ -143,4 +143,4 @@ def test_truncated_normal(assert_equal, assert_all_close, dtype, n, y):
|
|||||||
assert_all_close(
|
assert_all_close(
|
||||||
actual_variance,
|
actual_variance,
|
||||||
expected_variance,
|
expected_variance,
|
||||||
rtol=6e-3 if dtype == dtypes.bfloat16 else 1e-3)
|
rtol=variance_rtol)
|
||||||
|
Loading…
Reference in New Issue
Block a user