Makes sure that the CPU, GPU and XLA kernels for the new stateful random ops generate the same integer outputs (using Philox algorithm).
PiperOrigin-RevId: 245848768
This commit is contained in:
parent
9428abe167
commit
0fbe859237
@ -52,13 +52,14 @@ def xla_device_name():
|
||||
|
||||
|
||||
ALGS = [random.RNG_ALG_PHILOX, random.RNG_ALG_THREEFRY]
|
||||
INTS = [dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64]
|
||||
|
||||
|
||||
# TODO(wangpeng): use parametrized tests to test both ThreeFry and Philox
|
||||
class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
"""Test cases for stateful random-number generator operators."""
|
||||
|
||||
_ints = [dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64]
|
||||
_ints = INTS
|
||||
_floats = [dtypes.bfloat16, dtypes.float32]
|
||||
|
||||
@parameterized.parameters(ALGS)
|
||||
@ -201,6 +202,20 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
self.assertAllEqual([(size+1)//2-1, counter_high+1, key],
|
||||
gen.state.read_value())
|
||||
|
||||
@parameterized.parameters(INTS)
|
||||
@test_util.run_v2_only
|
||||
def testXLAEqualsCPU(self, dtype):
|
||||
"""Tests that XLA and CPU kernels generate the same integers."""
|
||||
seed = 1234
|
||||
shape = [315, 49]
|
||||
with ops.device("/device:CPU:0"):
|
||||
cpu = (random.Generator(seed=seed, algorithm=random.RNG_ALG_PHILOX)
|
||||
.uniform_full_int(shape=shape, dtype=dtype))
|
||||
with ops.device(xla_device_name()):
|
||||
xla = (random.Generator(seed=seed, algorithm=random.RNG_ALG_PHILOX)
|
||||
.uniform_full_int(shape=shape, dtype=dtype))
|
||||
self.assertAllEqual(cpu, xla)
|
||||
|
||||
def _testRngIsNotConstant(self, rng, dtype):
|
||||
# Tests that 'rng' does not always return the same value.
|
||||
# The random-number generator, if working correctly, should produce the
|
||||
|
@ -353,6 +353,21 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
"""
|
||||
self._sameAsOldRandomOps(test_util.gpu_device_name(), GPU_FLOATS)
|
||||
|
||||
@parameterized.parameters(INTS + [dtypes.uint32, dtypes.uint64])
|
||||
@test_util.run_v2_only
|
||||
@test_util.run_cuda_only
|
||||
def testGPUEqualsCPU(self, dtype):
|
||||
"""Tests that GPU and CPU generate the same integer outputs."""
|
||||
seed = 1234
|
||||
shape = [315, 49]
|
||||
with ops.device("/device:CPU:0"):
|
||||
cpu = random.Generator(seed=seed).uniform_full_int(
|
||||
shape=shape, dtype=dtype)
|
||||
with ops.device(test_util.gpu_device_name()):
|
||||
gpu = random.Generator(seed=seed).uniform_full_int(
|
||||
shape=shape, dtype=dtype)
|
||||
self.assertAllEqual(cpu, gpu)
|
||||
|
||||
@parameterized.parameters(FLOATS + INTS)
|
||||
@test_util.run_v2_only
|
||||
def testUniformIsInRange(self, dtype):
|
||||
|
Loading…
x
Reference in New Issue
Block a user