[TF] Add test cases to check that random ops are stateful.
Check that running the same random op multiple times in the same session rarely produces the same result. PiperOrigin-RevId: 206764062
This commit is contained in:
parent
2826d123a0
commit
217dd20c5e
@ -24,13 +24,42 @@ from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class RandomNormalTest(test.TestCase):
|
||||
class RandomOpTestCommon(test.TestCase):
|
||||
|
||||
# Checks that executing the same rng_func multiple times rarely produces the
|
||||
# same result.
|
||||
def _testSingleSessionNotConstant(self,
|
||||
rng_func,
|
||||
num,
|
||||
dtype,
|
||||
min_or_mean,
|
||||
max_or_stddev,
|
||||
use_gpu,
|
||||
op_seed=None,
|
||||
graph_seed=None):
|
||||
with self.test_session(use_gpu=use_gpu, graph=ops.Graph()) as sess:
|
||||
if graph_seed is not None:
|
||||
random_seed.set_random_seed(graph_seed)
|
||||
x = rng_func([num], min_or_mean, max_or_stddev, dtype=dtype, seed=op_seed)
|
||||
|
||||
y = sess.run(x)
|
||||
z = sess.run(x)
|
||||
w = sess.run(x)
|
||||
|
||||
# We use exact equality here. If the random-number generator is producing
|
||||
# the same output, all three outputs will be bitwise identical.
|
||||
self.assertTrue((not np.array_equal(y, z)) or
|
||||
(not np.array_equal(z, w)) or (not np.array_equal(y, w)))
|
||||
|
||||
|
||||
class RandomNormalTest(RandomOpTestCommon):
|
||||
|
||||
def _Sampler(self, num, mu, sigma, dtype, use_gpu, seed=None):
|
||||
|
||||
@ -90,6 +119,36 @@ class RandomNormalTest(test.TestCase):
|
||||
diff = rnd2 - rnd1
|
||||
self.assertTrue(np.linalg.norm(diff.eval()) > 0.1)
|
||||
|
||||
def testSingleSessionNotConstant(self):
|
||||
for use_gpu in [False, True]:
|
||||
for dt in dtypes.float16, dtypes.float32, dtypes.float64:
|
||||
self._testSingleSessionNotConstant(
|
||||
random_ops.random_normal, 100, dt, 0.0, 1.0, use_gpu=use_gpu)
|
||||
|
||||
def testSingleSessionOpSeedNotConstant(self):
|
||||
for use_gpu in [False, True]:
|
||||
for dt in dtypes.float16, dtypes.float32, dtypes.float64:
|
||||
self._testSingleSessionNotConstant(
|
||||
random_ops.random_normal,
|
||||
100,
|
||||
dt,
|
||||
0.0,
|
||||
1.0,
|
||||
use_gpu=use_gpu,
|
||||
op_seed=1345)
|
||||
|
||||
def testSingleSessionGraphSeedNotConstant(self):
|
||||
for use_gpu in [False, True]:
|
||||
for dt in dtypes.float16, dtypes.float32, dtypes.float64:
|
||||
self._testSingleSessionNotConstant(
|
||||
random_ops.random_normal,
|
||||
100,
|
||||
dt,
|
||||
0.0,
|
||||
1.0,
|
||||
use_gpu=use_gpu,
|
||||
graph_seed=965)
|
||||
|
||||
|
||||
class TruncatedNormalTest(test.TestCase):
|
||||
|
||||
@ -187,7 +246,7 @@ class TruncatedNormalTest(test.TestCase):
|
||||
self.assertAllEqual(rnd1, rnd2)
|
||||
|
||||
|
||||
class RandomUniformTest(test.TestCase):
|
||||
class RandomUniformTest(RandomOpTestCommon):
|
||||
|
||||
def _Sampler(self, num, minv, maxv, dtype, use_gpu, seed=None):
|
||||
|
||||
@ -291,6 +350,39 @@ class RandomUniformTest(test.TestCase):
|
||||
diff = (rnd2 - rnd1).eval()
|
||||
self.assertTrue(np.linalg.norm(diff) > 0.1)
|
||||
|
||||
def testSingleSessionNotConstant(self):
|
||||
for use_gpu in [False, True]:
|
||||
for dt in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32,
|
||||
dtypes.int64):
|
||||
self._testSingleSessionNotConstant(
|
||||
random_ops.random_uniform, 100, dt, 0, 17, use_gpu=use_gpu)
|
||||
|
||||
def testSingleSessionOpSeedNotConstant(self):
|
||||
for use_gpu in [False, True]:
|
||||
for dt in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32,
|
||||
dtypes.int64):
|
||||
self._testSingleSessionNotConstant(
|
||||
random_ops.random_uniform,
|
||||
100,
|
||||
dt,
|
||||
10,
|
||||
20,
|
||||
use_gpu=use_gpu,
|
||||
op_seed=1345)
|
||||
|
||||
def testSingleSessionGraphSeedNotConstant(self):
|
||||
for use_gpu in [False, True]:
|
||||
for dt in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32,
|
||||
dtypes.int64):
|
||||
self._testSingleSessionNotConstant(
|
||||
random_ops.random_uniform,
|
||||
100,
|
||||
dt,
|
||||
20,
|
||||
200,
|
||||
use_gpu=use_gpu,
|
||||
graph_seed=965)
|
||||
|
||||
|
||||
class RandomShapeTest(test.TestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user