Truncate random seed to fit into int during protobuf serialization (#2495)
This commit is contained in:
parent
175e9f73b3
commit
85b16455b8
@ -24,6 +24,10 @@ from tensorflow.python.framework import ops
|
|||||||
|
|
||||||
|
|
||||||
_DEFAULT_GRAPH_SEED = 87654321
|
_DEFAULT_GRAPH_SEED = 87654321
|
||||||
|
_MAXINT32 = 2**31 - 1
|
||||||
|
|
||||||
|
def _truncate_seed(seed):
|
||||||
|
return seed % _MAXINT32 # truncate to fit into 32-bit integer
|
||||||
|
|
||||||
|
|
||||||
def get_seed(op_seed):
|
def get_seed(op_seed):
|
||||||
@ -47,12 +51,12 @@ def get_seed(op_seed):
|
|||||||
graph_seed = ops.get_default_graph().seed
|
graph_seed = ops.get_default_graph().seed
|
||||||
if graph_seed is not None:
|
if graph_seed is not None:
|
||||||
if op_seed is not None:
|
if op_seed is not None:
|
||||||
return graph_seed, op_seed
|
return _truncate_seed(graph_seed), _truncate_seed(op_seed)
|
||||||
else:
|
else:
|
||||||
return graph_seed, ops.get_default_graph()._last_id
|
return _truncate_seed(graph_seed), _truncate_seed(ops.get_default_graph()._last_id)
|
||||||
else:
|
else:
|
||||||
if op_seed is not None:
|
if op_seed is not None:
|
||||||
return _DEFAULT_GRAPH_SEED, op_seed
|
return _truncate_seed(_DEFAULT_GRAPH_SEED), _truncate_seed(op_seed)
|
||||||
else:
|
else:
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
@ -237,9 +237,10 @@ class RandomUniformTest(tf.test.TestCase):
|
|||||||
def testSeed(self):
|
def testSeed(self):
|
||||||
for use_gpu in False, True:
|
for use_gpu in False, True:
|
||||||
for dt in tf.float16, tf.float32, tf.float64, tf.int32, tf.int64:
|
for dt in tf.float16, tf.float32, tf.float64, tf.int32, tf.int64:
|
||||||
sx = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=345)
|
for seed in [345, 2**100, -2**100]:
|
||||||
sy = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=345)
|
sx = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=seed)
|
||||||
self.assertAllEqual(sx(), sy())
|
sy = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=seed)
|
||||||
|
self.assertAllEqual(sx(), sy())
|
||||||
|
|
||||||
def testNoCSE(self):
|
def testNoCSE(self):
|
||||||
shape = [2, 3, 4]
|
shape = [2, 3, 4]
|
||||||
|
Loading…
Reference in New Issue
Block a user