From 85b16455b872ab9a4c02ab3394a14354533ee7ae Mon Sep 17 00:00:00 2001 From: Maxim Grechkin Date: Wed, 25 May 2016 15:30:00 -0700 Subject: [PATCH] Truncate random seed to fit into int during protobuf serialization (#2495) --- tensorflow/python/framework/random_seed.py | 10 +++++++--- tensorflow/python/kernel_tests/random_ops_test.py | 7 ++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/framework/random_seed.py b/tensorflow/python/framework/random_seed.py index b70f626a9ee..9f503b8f29c 100644 --- a/tensorflow/python/framework/random_seed.py +++ b/tensorflow/python/framework/random_seed.py @@ -24,6 +24,10 @@ from tensorflow.python.framework import ops _DEFAULT_GRAPH_SEED = 87654321 +_MAXINT32 = 2**31 - 1 + +def _truncate_seed(seed): + return seed % _MAXINT32 # truncate to fit into 32-bit integer def get_seed(op_seed): @@ -47,12 +51,12 @@ def get_seed(op_seed): graph_seed = ops.get_default_graph().seed if graph_seed is not None: if op_seed is not None: - return graph_seed, op_seed + return _truncate_seed(graph_seed), _truncate_seed(op_seed) else: - return graph_seed, ops.get_default_graph()._last_id + return _truncate_seed(graph_seed), _truncate_seed(ops.get_default_graph()._last_id) else: if op_seed is not None: - return _DEFAULT_GRAPH_SEED, op_seed + return _truncate_seed(_DEFAULT_GRAPH_SEED), _truncate_seed(op_seed) else: return None, None diff --git a/tensorflow/python/kernel_tests/random_ops_test.py b/tensorflow/python/kernel_tests/random_ops_test.py index 45b61be0c31..f4ed26b1e25 100644 --- a/tensorflow/python/kernel_tests/random_ops_test.py +++ b/tensorflow/python/kernel_tests/random_ops_test.py @@ -237,9 +237,10 @@ class RandomUniformTest(tf.test.TestCase): def testSeed(self): for use_gpu in False, True: for dt in tf.float16, tf.float32, tf.float64, tf.int32, tf.int64: - sx = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=345) - sy = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=345) - self.assertAllEqual(sx(), sy()) + for seed in [345, 2**100, -2**100]: + sx = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=seed) + sy = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=seed) + self.assertAllEqual(sx(), sy()) def testNoCSE(self): shape = [2, 3, 4]