Truncate random seed to fit into int during protobuf serialization (#2495)

This commit is contained in:
Maxim Grechkin 2016-05-25 15:30:00 -07:00 committed by Vijay Vasudevan
parent 175e9f73b3
commit 85b16455b8
2 changed files with 11 additions and 6 deletions

View File

@ -24,6 +24,10 @@ from tensorflow.python.framework import ops
_DEFAULT_GRAPH_SEED = 87654321
_MAXINT32 = 2**31 - 1
def _truncate_seed(seed):
return seed % _MAXINT32 # truncate to fit into 32-bit integer
def get_seed(op_seed):
@ -47,12 +51,12 @@ def get_seed(op_seed):
graph_seed = ops.get_default_graph().seed
if graph_seed is not None:
if op_seed is not None:
return graph_seed, op_seed
return _truncate_seed(graph_seed), _truncate_seed(op_seed)
else:
return graph_seed, ops.get_default_graph()._last_id
return _truncate_seed(graph_seed), _truncate_seed(ops.get_default_graph()._last_id)
else:
if op_seed is not None:
return _DEFAULT_GRAPH_SEED, op_seed
return _truncate_seed(_DEFAULT_GRAPH_SEED), _truncate_seed(op_seed)
else:
return None, None

View File

@ -237,9 +237,10 @@ class RandomUniformTest(tf.test.TestCase):
def testSeed(self):
for use_gpu in False, True:
for dt in tf.float16, tf.float32, tf.float64, tf.int32, tf.int64:
sx = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=345)
sy = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=345)
self.assertAllEqual(sx(), sy())
for seed in [345, 2**100, -2**100]:
sx = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=seed)
sy = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=seed)
self.assertAllEqual(sx(), sy())
def testNoCSE(self):
shape = [2, 3, 4]