# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for stateful_random_ops.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

from absl.testing import parameterized
import numpy as np
import six

from tensorflow.python.distribute import values as dist_values
from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.kernel_tests.random import util as \
random_test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import gen_stateful_random_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import stateful_random_ops as \
random
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


g_seeded = None
g_unseeded = None


GPU_FLOATS = [dtypes.float16, dtypes.float32, dtypes.float64]
CPU_FLOATS = GPU_FLOATS + [dtypes.bfloat16]
FLOATS = GPU_FLOATS
INTS = [dtypes.int32, dtypes.int64]
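
# bfloat16 is in CPU_FLOATS but not GPU_FLOATS: these tests exercise bfloat16
# only on CPU, while GPU-capable tests draw their dtypes from GPU_FLOATS
# (aliased as FLOATS).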


class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(StatefulRandomOpsTest, self).setUp()
    physical_devices = config.list_physical_devices("CPU")
    config.set_logical_device_configuration(
        physical_devices[0], [
            context.LogicalDeviceConfiguration(),
            context.LogicalDeviceConfiguration()
        ])

  def testCreateRNGStateIntSeed(self):
    """Tests `create_rng_state` when `seed` is int."""
    # using leading 'F' to test overflow tolerance
    state = random.create_rng_state(0xFFFF222233334444FFAA666677778888,
                                    random.RNG_ALG_PHILOX)
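    # The 128-bit seed is split into two uint64 words, low word first; the
    # remaining Philox state entries are zero-filled. `_uint_to_int` maps each
    # word to its signed-int64 representation, which the leading 0xF nibbles
    # deliberately push out of the non-negative range.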
    self.assertAllEqual(
        list(map(random._uint_to_int,
                 [0xFFAA666677778888, 0xFFFF222233334444] +
                 [0] * (random.PHILOX_STATE_SIZE - 2))),
        state)

  def assertAllDifferent(self, tensors):
    """Checks that there are no duplicate elements anywhere among the tensors.

    Args:
      tensors: a list of tensors. They can have different shapes.
    """
    tensors = [array_ops.reshape(t, shape=[-1]) for t in tensors]
    ls = array_ops.concat(tensors, axis=0).numpy().tolist()
    self.assertAllEqual(len(ls), len(set(ls)))

  @test_util.run_v2_only
  def testNonDeterministicInts(self):
    """Tests that non_deterministic_ints returns different results every time.

    This test is flaky, but with very low probability of failing.
    """
    shape = [2, 3]
    dtype = dtypes.int64
    a = random.non_deterministic_ints(shape=shape, dtype=dtype)
    self.assertAllEqual(shape, a.shape)
    self.assertEqual(dtype, a.dtype)
    b = random.non_deterministic_ints(shape, dtype=dtype)
    self.assertAllDifferent([a, b])

  @test_util.run_v2_only
  def testBatchSeeds(self):
    """Test for batch seeds."""
    shape = [2, 3]
    count = 6
    gen = random.Generator.from_seed(1234)
    keys1 = gen._make_int64_keys(shape=shape)
    keys2 = gen._make_int64_keys(shape=shape)
    self.assertAllDifferent([keys1, keys2])
    seeds1 = gen.make_seeds(count=count)
    seeds2 = gen.make_seeds(count=count)
    self.assertAllDifferent([seeds1[0, :], seeds2[0, :]])
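    # `split` deterministically derives `count` child generators from `gen`,
    # advancing `gen`'s own state in the process, so the children's streams
    # are expected to differ from each other (checked below).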
    gens = gen.split(count=count)
    self.assertAllEqual(count, len(gens))
    randoms = [g.uniform_full_int(shape=shape, dtype=dtypes.int32)
               for g in gens]
    self.assertAllDifferent(randoms)
    # Tests graph mode.
    @def_function.function
    def f():
      return gen.make_seeds(count=count)
    for _ in range(3):
      f()

  def assertRegex(self, pattern, text):
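    # Note: unlike unittest's assertRegex(text, expected_regex), this override
    # takes the pattern first and the text second, matching the call sites in
    # this file.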
    self.assertTrue(
        re.search(pattern, text),
        "Can't find pattern '%s' in text '%s'" % (pattern, text))

  @test_util.run_v2_only
  @test_util.run_cuda_only
  def testCrossDeviceSplit(self):
    """Tests that a CPU RNG can split into RNGs on GPU."""
    with ops.device("/device:CPU:0"):
      gen = random.Generator.from_seed(1234)  # gen is on CPU
      self.assertRegex("CPU", gen.state.device)
    with ops.device(test_util.gpu_device_name()):
      gens = gen.split(count=10)  # gens are on GPU
      self.assertRegex("GPU", gens[0].state.device)

  @test_util.run_v2_only
  def testReset(self):
    shape = [2, 3]
    gen = random.Generator.from_seed(0)
    for resetter in [
        lambda g: g.reset(state=[1, 2, 3]),
        lambda g: g.reset_from_seed(1234),
        lambda g: g.reset_from_key_counter(key=1, counter=[2, 3]),
    ]:
      resetter(gen)
      expected_normal = gen.normal(shape)
      @def_function.function
      def f(resetter):
        resetter(gen)
        return gen.normal(shape)
      def check_results(expected_normal, v):
        self.assertAllEqual(expected_normal, v)
      check_results(expected_normal, f(resetter))
      check_results(expected_normal, f(resetter))

  @test_util.run_v2_only
  def testGeneratorCreation(self):
    """Tests generator creation, in both eager and tf.function.

    The interaction between Generator creation and defun should be the same
    as for tf.Variable.
    """
    shape = [2, 3]
    alg = random.RNG_ALG_PHILOX
    for constructor in [
        lambda: random.Generator(state=[1, 2, 3], alg=alg),
        lambda: random.Generator.from_seed(1234),
        lambda: random.Generator.from_key_counter(  # pylint: disable=g-long-lambda
            key=1, counter=[2, 3], alg=alg),
    ]:
      gen = constructor()
      # Tests tf.function
      expected_normal1 = gen.normal(shape)
      expected_normal2 = gen.normal(shape)
      global g_seeded
      g_seeded = None
      @def_function.function
      def f(constructor):
        global g_seeded
        # defun'ed function should only create variables once
        if g_seeded is None:
          g_seeded = constructor()
        return g_seeded.normal(shape)
      def check_results(expected_normal, v):
        self.assertAllEqual(expected_normal, v)
      check_results(expected_normal1, f(constructor))
      check_results(expected_normal2, f(constructor))

  @parameterized.parameters([
      ("philox", random.RNG_ALG_PHILOX, random.Algorithm.PHILOX),
      ("threefry", random.RNG_ALG_THREEFRY, random.Algorithm.THREEFRY)])
  @test_util.run_v2_only
  def testAlg(self, name, int_id, enum_id):
    g_by_name = random.Generator.from_seed(1234, name)
    g_by_int = random.Generator.from_seed(1234, int_id)
    g_by_enum = random.Generator.from_seed(1234, enum_id)
    self.assertEqual(g_by_name.algorithm, g_by_int.algorithm)
    self.assertEqual(g_by_name.algorithm, g_by_enum.algorithm)

  @test_util.run_v2_only
  def testGeneratorCreationWithVar(self):
    """Tests creating a generator with a variable."""
    alg = random.RNG_ALG_PHILOX
    state = [1, 2, 3]
    var = variables.Variable(state, dtype=random.STATE_TYPE)
    g = random.Generator(state=state, alg=alg)
    g_var = random.Generator(state=var, alg=alg)
    shape = [2, 3]
    g.normal(shape)
    g_var.normal(shape)
    self.assertAllEqual(g.state.read_value(), var.read_value())

  @test_util.run_v2_only
  def testGeneratorCreationUnseeded(self):
    """Tests generator creation, the unseeded case."""
    shape = [2, 3]
    global g_unseeded
    g_unseeded = None
    @def_function.function
    def f():
      global g_unseeded
      # defun'ed function should only create variables once
      if g_unseeded is None:
        g_unseeded = random.Generator.from_non_deterministic_state()
      return g_unseeded.normal(shape)
    self.assertAllEqual(shape, f().shape)

  @test_util.run_v2_only
  def testGeneratorCopy(self):
    """Tests copying a generator."""
    g = random.Generator.from_seed(0)
    g_copy = random.Generator(g)
    self.assertAllEqual(g.algorithm, g_copy.algorithm)
    self.assertAllEqual(g.state.read_value(), g_copy.state.read_value())
    # Tests tf.function
    global g_seeded
    g_seeded = None
    # Do the same in tf.function
    @def_function.function
    def f():
      global g_seeded
      # defun'ed function should only create variables once
      if g_seeded is None:
        g_seeded = random.Generator(g)
      self.assertAllEqual(g.algorithm, g_seeded.algorithm)
      self.assertAllEqual(g.state.read_value(), g_seeded.state.read_value())
    f()

  @test_util.run_v1_only(
      ("This test is specifically for checking TF1 compatibility. "
       "It cannot run under TF2."))
  def testTF1(self):
    seed = 1234
    shape = [2, 3]
    expected_normal1 = constant_op.constant(
        [[0.9356609, 1.0854305, -0.93788373],
         [-0.50615472, 1.31697023, 0.71375787]], dtype=dtypes.float32)
    expected_normal2 = constant_op.constant(
        [[-0.3964749, 0.8369565, -0.30946946],
         [1.1206646, 1.00852597, -0.10185789]], dtype=dtypes.float32)
    with self.cached_session() as sess:
      gen1 = random.Generator.from_seed(seed)
      gen2 = random.Generator.from_non_deterministic_state()
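      # In TF1 graph mode the generators' state variables are not initialized
      # automatically, so run their initializers explicitly before sampling.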
      sess.run((gen1._state_var.initializer, gen2._state_var.initializer))
      r1 = gen1.normal(shape, dtype=dtypes.float32)
      r2 = gen2.normal(shape, dtype=dtypes.float32)
      def f():
        return sess.run((r1, r2))
      def check_results(expected_normal, v1, v2):
        self.assertAllClose(expected_normal, v1, rtol=1e-5, atol=1e-5)
        self.assertAllEqual(shape, v2.shape)
      check_results(expected_normal1, *f())
      check_results(expected_normal2, *f())

  @test_util.run_v2_only
  @test_util.also_run_as_tf_function
  def testEagerAndDefun(self):
    """A simple test to make sure the op works in eager and defunned mode."""
    random.get_global_generator().normal((3,))

  @test_util.run_v2_only
  def testOpSeedSelectionAfterSetSeed(self):
    """Tests that op-seed selection is reset when the global generator is reset.

    Fixing GitHub issue 9171:
    https://github.com/tensorflow/tensorflow/issues/9171
    """
    shape = (3,)
    random.get_global_generator().reset_from_seed(1)
    a = random.get_global_generator().normal(shape)
    random.get_global_generator().reset_from_seed(1)
    b = random.get_global_generator().normal(shape)
    self.assertAllEqual(a, b)

    # Now do the above again using accelerated ('defun'ed) computation
    @def_function.function
    def f():
      return random.get_global_generator().normal(shape)

    random.get_global_generator().reset_from_seed(1)
    c = f()
    random.get_global_generator().reset_from_seed(1)
    d = f()
    self.assertAllEqual(c, d)
    self.assertAllEqual(a, c)

  @test_util.run_v2_only
  def testOpSeedSelectionNotSensitive(self):
    """Test that op-seed selection is not sensitive to trivial changes.

    Test that op-seed selection is not sensitive to trivial computation
    (i.e. graph) changes.

    Fixing b/32087099
    """
    def f(include_print):
      shape = constant_op.constant([5])
      if include_print:
        shape = logging_ops.Print(shape, [shape])
      return random.get_global_generator().normal(shape)

    def compare(fst_includes_print, snd_includes_print):
      random.get_global_generator().reset_from_seed(50)
      fst = f(fst_includes_print)
      random.get_global_generator().reset_from_seed(50)
      snd = f(snd_includes_print)
      self.assertAllEqual(fst, snd)
      # Now do the above again using accelerated (defunned) 'f'.
      # Running 'f' with two different Boolean arguments should cause
      # two different graphs to be generated, hence demonstrating the
      # insensitivity to graph changes.
      f_acc = def_function.function(f)
      random.get_global_generator().reset_from_seed(50)
      fst = f_acc(fst_includes_print)
      random.get_global_generator().reset_from_seed(50)
      snd = f_acc(snd_includes_print)
      self.assertAllEqual(fst, snd)

    compare(False, False)
    compare(True, True)
    compare(True, False)

  @test_util.run_v2_only
  def testKey(self):
    key = 1234
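    # For Philox the state layout is [counter_low, counter_high, key], so the
    # third state word below is what `gen.key` should report.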
    gen = random.Generator(state=[0, 0, key], alg=random.RNG_ALG_PHILOX)
    got = gen.key
    self.assertAllEqual(key, got)
    @def_function.function
    def f():
      return gen.key
    got = f()
    self.assertAllEqual(key, got)

  @test_util.run_v2_only
  def testSkip(self):
    key = 1234
    counter = 5678
    gen = random.Generator(state=[counter, 0, key], alg=random.RNG_ALG_PHILOX)
    delta = 432
    gen.skip(delta)
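    # `skip(delta)` is expected to advance the Philox counter (state word 0)
    # by 256 per unit of delta, which the assertion below checks.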
    new_counter = gen._state_var[0]
    self.assertAllEqual(counter + delta * 256, new_counter)

  def _sameAsOldRandomOps(self, device, floats):
    def compare(dtype, old, new):
      seed1, seed2 = 79, 25
      # note how the two seeds for the old op correspond to the seed for the
      # new op
      with ops.device(device):
        gen = random.Generator(state=[0, seed2, seed1],
                               alg=random.RNG_ALG_PHILOX)

      # create a graph for the old op in order to call it many times
      @def_function.function
      def run_old():
        with ops.device(device):
          return old(dtype, seed1, seed2)

      def run_new():
        with ops.device(device):
          return new(dtype, gen)

      for _ in range(100):
        self.assertAllEqual(run_old(), run_new())

    shape = constant_op.constant([4, 7])
    minval = 128
    maxval = 256

    # passing `dtype` around to suppress go/gpylint-faq#cell-var-from-loop and
    # go/gpylint-faq#undefined-loop-variable
    def old_normal(dtype, seed1, seed2):
      return gen_random_ops.random_standard_normal(
          shape, dtype=dtype, seed=seed1, seed2=seed2)
    def new_normal(dtype, gen):
      return gen._standard_normal(shape, dtype=dtype)
    def old_truncated_normal(dtype, seed1, seed2):
      return gen_random_ops.truncated_normal(
          shape, dtype=dtype, seed=seed1, seed2=seed2)
    def new_truncated_normal(dtype, gen):
      return gen._truncated_normal(shape, dtype=dtype)
    def old_uniform_int(dtype, seed1, seed2):
      minval2 = constant_op.constant(minval, dtype=dtype)
      maxval2 = constant_op.constant(maxval, dtype=dtype)
      return gen_random_ops.random_uniform_int(
          shape, minval=minval2, maxval=maxval2, seed=seed1, seed2=seed2)
    def new_uniform_int(dtype, gen):
      return gen.uniform(shape, minval=minval, maxval=maxval, dtype=dtype)
    def old_uniform(dtype, seed1, seed2):
      return gen_random_ops.random_uniform(
          shape, dtype=dtype, seed=seed1, seed2=seed2)
    def new_uniform(dtype, gen):
      return gen._uniform(shape, dtype=dtype)

    for dtype in floats:
      compare(dtype, old_normal, new_normal)
      compare(dtype, old_truncated_normal, new_truncated_normal)
      compare(dtype, old_uniform, new_uniform)
    for dtype in INTS:
      compare(dtype, old_uniform_int, new_uniform_int)

  @test_util.run_v2_only
  def testSameAsOldRandomOpsCPU(self):
    """Tests that the generated numbers are the same as the old random_ops.py.

    The CPU version.
    """
    self._sameAsOldRandomOps("/device:CPU:0", CPU_FLOATS)

  @test_util.run_v2_only
  @test_util.run_cuda_only
  def testSameAsOldRandomOpsGPU(self):
    """Tests that the generated numbers are the same as the old random_ops.py.

    The GPU version.
    """
    self._sameAsOldRandomOps(test_util.gpu_device_name(), GPU_FLOATS)

  @parameterized.parameters(INTS + [dtypes.uint32, dtypes.uint64])
  @test_util.run_v2_only
  @test_util.run_cuda_only
  def testGPUEqualsCPU(self, dtype):
    """Tests that GPU and CPU generate the same integer outputs."""
    seed = 1234
    shape = [315, 49]
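    # Philox is counter-based, so for a fixed seed the integer stream should
    # not depend on which device runs the kernel.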
    with ops.device("/device:CPU:0"):
      cpu = random.Generator.from_seed(seed).uniform_full_int(
          shape=shape, dtype=dtype)
    with ops.device(test_util.gpu_device_name()):
      gpu = random.Generator.from_seed(seed).uniform_full_int(
          shape=shape, dtype=dtype)
    self.assertAllEqual(cpu, gpu)

  @parameterized.parameters(FLOATS + INTS)
  @test_util.run_v2_only
  def testUniformIsInRange(self, dtype):
    minval = 2
    maxval = 33
    size = 1000
    gen = random.Generator.from_seed(1234)
    x = gen.uniform(
        shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy()
    self.assertTrue(np.all(x >= minval))
    self.assertTrue(np.all(x < maxval))

  @parameterized.parameters(FLOATS)
  @test_util.run_v2_only
  def testNormalIsFinite(self, dtype):
    gen = random.Generator.from_seed(1234)
    x = gen.normal(shape=[10000], dtype=dtype).numpy()
    self.assertTrue(np.all(np.isfinite(x)))

  @parameterized.parameters(FLOATS + INTS)
  @test_util.run_v2_only
  def testDistributionOfUniform(self, dtype):
    """Use Pearson's Chi-squared test to test for uniformity."""
    n = 1000
    seed = 12
    gen = random.Generator.from_seed(seed)
    maxval = 1
    if dtype.is_integer:
      maxval = 100
    x = gen.uniform(shape=[n], maxval=maxval, dtype=dtype).numpy()
    if maxval > 1:
      # Normalize x to the range [0, 1).
      x = x.astype(float) / maxval
    # Tests that the values are distributed amongst 10 bins with equal
    # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with
    # p=0.05. This test is probabilistic and would be flaky if the random
    # seed were not fixed.
    val = random_test_util.chi_squared(x, 10)
    self.assertLess(val, 16.92)

  @parameterized.parameters(FLOATS)
  @test_util.run_v2_only
  def testDistributionOfNormal(self, dtype):
    """Use Anderson-Darling test to test distribution appears normal."""
    n = 1000
    gen = random.Generator.from_seed(1234)
    x = gen.normal(shape=[n], dtype=dtype).numpy()
    # The constant 2.492 is the 5% critical value for the Anderson-Darling
    # test where the mean and variance are known. This test is probabilistic
    # so to avoid flakiness the seed is fixed.
    self.assertLess(
        random_test_util.anderson_darling(x.astype(float)), 2.492)

  @test_util.run_v2_only
  def testErrors(self):
    """Tests that proper errors are raised."""
    shape = [2, 3]
    gen = random.Generator.from_seed(1234)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        r"must have shape \[\], not"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          gen.state.handle, [0, 0], shape)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        r"must have shape \[\], not"):
      gen_stateful_random_ops.rng_skip(
          gen.state.handle, gen.algorithm, [0, 0])
    with self.assertRaisesWithPredicateMatch(
        TypeError, "EagerTensor of dtype int64"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          gen.state.handle, 1.1, shape)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        "Unsupported algorithm id"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          gen.state.handle, 123, shape)
    var = variables.Variable([0, 0], dtype=dtypes.int32)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        "dtype of RNG state variable must be int64, not"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          var.handle, random.RNG_ALG_PHILOX, shape)
    var = variables.Variable([[0]], dtype=dtypes.int64)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        "RNG state must have one and only one dimension, not"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          var.handle, random.RNG_ALG_PHILOX, shape)
    var = variables.Variable([0], dtype=dtypes.int64)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        "For the Philox algorithm, the size of state must be at least"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          var.handle, random.RNG_ALG_PHILOX, shape)
    with self.assertRaisesWithPredicateMatch(
        ValueError,
        "minval must be a scalar; got a tensor of shape "):
      @def_function.function
      def f():
        gen.uniform(shape=shape, minval=array_ops.zeros(shape, "int32"),
                    maxval=100, dtype="int32")
      f()
    with self.assertRaisesWithPredicateMatch(
        ValueError,
        "maxval must be a scalar; got a tensor of shape "):
      @def_function.function
      def f2():
        gen.uniform(
            shape=shape, minval=0, maxval=array_ops.ones(shape, "int32") * 100,
            dtype="int32")
      f2()

  @test_util.run_v2_only
  def testGetGlobalGeneratorWithXla(self):
    """Demonstrates using the global generator with XLA."""
    if not config.list_physical_devices("XLA_CPU"):
      self.skipTest("No XLA_CPU device available.")

    random.set_global_generator(None)

    @def_function.function(experimental_compile=True)
    def make_seed():
      generator = random.get_global_generator()
      state = array_ops.identity(generator.state, name="state")
      return generator.uniform_full_int((2,), dtypes.int32, name="seed"), state

    with ops.device("/device:XLA_CPU:0"):
      seed, state = make_seed()
    self.assertTrue(np.all(np.isfinite(seed.numpy())))
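    # Restoring the state captured inside the compiled function should make
    # the next call reproduce the same seed.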
    random.get_global_generator().reset(state)
    self.assertAllEqual(make_seed()[0], seed)

  @test_util.run_v2_only
  def testSetGlobalGeneratorBadWithDefun(self):
    """Demonstrates that set_global_generator doesn't work properly with defun.
    """
    shape = (3,)

    @def_function.function
    def f():
      return random.get_global_generator().normal(shape)
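
    # `f` captures the state variable of whichever global generator is current
    # when it is first traced; replacing the global generator afterwards
    # destroys that variable, so the traced function fails on its next call.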
    random.set_global_generator(random.Generator.from_seed(50))
    with self.assertRaisesWithPredicateMatch(
        errors.NotFoundError, "Resource .+ does not exist"):
      _ = f()
      random.set_global_generator(random.Generator.from_seed(50))
      _ = f()

  @test_util.run_v2_only
  def testFunctionArg(self):
    """Tests that an RNG can be used as a tf.function argument."""
    shape = [2, 3]
    @def_function.function
    def f(gen):
      return gen.normal(shape)
    g1 = random.Generator.from_seed(1)
    g2 = random.Generator.from_seed(1)
    res1 = f(g1)
    res2 = g2.normal(shape)
    self.assertAllEqual(res1, res2)
    self.assertAllEqual(g1.state.read_value(), g2.state.read_value())

  @test_util.run_v2_only
  def testFunArgAlgIsInt(self):
    """Tests that `algorithm` is `int` when reconstructed from composite tensor.
    """
    @def_function.function
    def f(g):
      self.assertIsInstance(g.algorithm, six.integer_types)
      return g.make_seeds(), g
    gen = random.Generator.from_seed(123, alg="philox")
    f(gen)

  @test_util.run_v2_only
  def testLimitedRetracingWithCompositeTensors(self):
    """Tests that RNGs with the same shape/dtype won't cause retracing."""
    trace_count = [0]

    @def_function.function
    def f(x):
      trace_count[0] += 1
      return x.normal([])
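
    # Both generators have the same GeneratorSpec (state shape/dtype), so the
    # second call should reuse the trace from the first.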
    f(random.Generator.from_seed(1))
    f(random.Generator.from_seed(2))
    self.assertEqual(trace_count[0], 1)

  def testMostSpecificCompatibleType(self):
    """Tests GeneratorSpec.most_specific_compatible_type."""
    spec = random.GeneratorSpec(shape=(2, 3), dtype=dtypes.int32)
    res = spec.most_specific_compatible_type(
        random.GeneratorSpec(shape=(2, 3), dtype=dtypes.int32))
    self.assertEqual(spec, res)
    with self.assertRaisesWithPredicateMatch(ValueError, ""):
      spec.most_specific_compatible_type(
          tensor_spec.TensorSpec(shape=(2, 3), dtype=dtypes.int32))
    with self.assertRaisesWithPredicateMatch(ValueError, ""):
      spec.most_specific_compatible_type(
          random.GeneratorSpec(shape=(2, 4), dtype=dtypes.int32))
    with self.assertRaisesWithPredicateMatch(ValueError, ""):
      spec.most_specific_compatible_type(
          random.GeneratorSpec(shape=(2, 3), dtype=dtypes.int64))

  @test_util.run_v2_only
  @test_util.run_cuda_only
  def testMirroredStratSeq(self):
    """Tests RNG/MirroredStrategy interaction #1.

    If an RNG is created outside strategy.scope(), all replicas will access
    the same RNG object, and accesses are serialized.
    """
    shape = [3, 4]
    dtype = dtypes.int32
    gen = random.Generator.from_seed(1234)
    strat = MirroredStrategy(devices=["/cpu:0", test_util.gpu_device_name()])
    with strat.scope():
      def f():
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t = array_ops.stack([t1, t2])
        return t
      results = strat.extended.call_for_each_replica(fn=f)
      values = results.values
      self.assertAllEqual(2, len(values))
      self.assertAllDifferent(values)

  @test_util.run_v2_only
  def testMirroredStratParaSyncDisallowed(self):
    """Tests that generator creation in MirroredStrategy is disallowed."""
    creators = [
        lambda: random.Generator.from_seed(1234),
        random.Generator.from_non_deterministic_state,
    ]
    shape = [3, 4]
    dtype = dtypes.int32
    strat = MirroredStrategy(devices=["cpu:0", "cpu:1"])
    for creator in creators:
      with strat.scope():
        with self.assertRaisesWithPredicateMatch(
            ValueError, "disallowed"):
          creator()  # pylint: disable=cell-var-from-loop
      def f():
        gen = creator()  # pylint: disable=cell-var-from-loop
        return gen.uniform_full_int(shape=shape, dtype=dtype)
      with self.assertRaisesWithPredicateMatch(
          ValueError, "disallowed"):
        strat.extended.call_for_each_replica(fn=f)

  @test_util.run_v2_only
  def testMirroredStratParaAsync(self):
    """Tests RNG/MirroredStrategy interaction #2.

    The user can create n independent RNGs outside strategy.scope(), where n
    is the number of replicas, and give one to each replica. The replicas can
    thus get different random-number streams.
    """
    shape = [3, 4]
    dtype = dtypes.int32
    gens = random.get_global_generator().split(count=2)
    devices = ["cpu:0", "cpu:1"]
    strat = MirroredStrategy(devices=devices)
    # Use `PerReplica` to specify which `gen` is sent to which replica
    gens = dist_values.PerReplica([[g] for g in gens])
    with strat.scope():
      def f(gen):
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t = array_ops.stack([t1, t2])
        return t
      results = strat.extended.call_for_each_replica(fn=f, args=gens)
      local_results = strat.experimental_local_results(results)
      self.assertAllEqual(2, len(local_results))
      self.assertAllDifferent(local_results)

  @test_util.run_v2_only
  def testUniformFullInt(self):
    """Tests full-range int uniform."""
    shape = [3, 4]
    dtype = dtypes.int32
    g = random.Generator.from_seed(1)
    r1 = g.uniform(shape=shape, dtype=dtype, minval=None)
    g = random.Generator.from_seed(1)
    r2 = g.uniform_full_int(shape=shape, dtype=dtype)
    self.assertAllEqual(r1, r2)


if __name__ == "__main__":
  test.main()