Truncate random seed to fit into int during protobuf serialization (#2495)
This commit is contained in:
parent
175e9f73b3
commit
85b16455b8
@ -24,6 +24,10 @@ from tensorflow.python.framework import ops
|
||||
|
||||
|
||||
_DEFAULT_GRAPH_SEED = 87654321
|
||||
_MAXINT32 = 2**31 - 1
|
||||
|
||||
def _truncate_seed(seed):
|
||||
return seed % _MAXINT32 # truncate to fit into 32-bit integer
|
||||
|
||||
|
||||
def get_seed(op_seed):
|
||||
@ -47,12 +51,12 @@ def get_seed(op_seed):
|
||||
graph_seed = ops.get_default_graph().seed
|
||||
if graph_seed is not None:
|
||||
if op_seed is not None:
|
||||
return graph_seed, op_seed
|
||||
return _truncate_seed(graph_seed), _truncate_seed(op_seed)
|
||||
else:
|
||||
return graph_seed, ops.get_default_graph()._last_id
|
||||
return _truncate_seed(graph_seed), _truncate_seed(ops.get_default_graph()._last_id)
|
||||
else:
|
||||
if op_seed is not None:
|
||||
return _DEFAULT_GRAPH_SEED, op_seed
|
||||
return _truncate_seed(_DEFAULT_GRAPH_SEED), _truncate_seed(op_seed)
|
||||
else:
|
||||
return None, None
|
||||
|
||||
|
@ -237,9 +237,10 @@ class RandomUniformTest(tf.test.TestCase):
|
||||
def testSeed(self):
|
||||
for use_gpu in False, True:
|
||||
for dt in tf.float16, tf.float32, tf.float64, tf.int32, tf.int64:
|
||||
sx = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=345)
|
||||
sy = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=345)
|
||||
self.assertAllEqual(sx(), sy())
|
||||
for seed in [345, 2**100, -2**100]:
|
||||
sx = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=seed)
|
||||
sy = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=seed)
|
||||
self.assertAllEqual(sx(), sy())
|
||||
|
||||
def testNoCSE(self):
|
||||
shape = [2, 3, 4]
|
||||
|
Loading…
Reference in New Issue
Block a user