Changes StatelessRandomGetKeyCounterAlg to pick RNG_DEFAULT on TPU instead of RNG_THREE_FRY. The latter causes an OOM for StatelessRandomNormalV2 because of excessive TPU padding. The old V1 ops also pick RNG_DEFAULT.
PiperOrigin-RevId: 341109746 Change-Id: Ia325ed8a40b6e5f9cb00d1cfdbb1ada697c5db1f
This commit is contained in:
parent
498c9045ff
commit
70d22236db
@ -29,6 +29,7 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.kernel_tests.random import util as \
|
||||
random_test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_stateless_random_ops_v2
|
||||
from tensorflow.python.ops import stateless_random_ops as stateless
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
@ -63,6 +64,18 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
|
||||
lambda x: stateless.stateless_random_normal([], seed=x), [x])
|
||||
f([1, 2])
|
||||
|
||||
def testLargeNormal(self):
|
||||
"""Tests an OOM bug of StatelessRandomNormalV2 on TPU."""
|
||||
with self.session() as sess, self.test_scope():
|
||||
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
|
||||
key, counter, alg = (gen_stateless_random_ops_v2.
|
||||
stateless_random_get_key_counter_alg(seed_t))
|
||||
x = gen_stateless_random_ops_v2.stateless_random_normal_v2(
|
||||
shape=[1024, 32000], key=key, counter=counter, dtype=dtypes.float32,
|
||||
alg=alg)
|
||||
y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
|
||||
self.assertAllEqual([1024, 32000], y.shape)
|
||||
|
||||
def testDeterminism(self):
|
||||
# Stateless values should be equal iff the seeds are equal (roughly)
|
||||
with self.session(), self.test_scope():
|
||||
|
@ -42,6 +42,10 @@ namespace {
|
||||
inline xla::RandomAlgorithm AlgorithmToRandomAlgorithm(Algorithm const& alg) {
|
||||
if (alg == RNG_ALG_PHILOX) {
|
||||
return xla::RandomAlgorithm::RNG_PHILOX;
|
||||
} else if (alg == RNG_ALG_THREEFRY) {
|
||||
return xla::RandomAlgorithm::RNG_THREE_FRY;
|
||||
} else if (alg == RNG_ALG_XLA_DEFAULT) {
|
||||
return xla::RandomAlgorithm::RNG_DEFAULT;
|
||||
}
|
||||
return xla::RandomAlgorithm::RNG_THREE_FRY;
|
||||
}
|
||||
@ -49,6 +53,10 @@ inline xla::RandomAlgorithm AlgorithmToRandomAlgorithm(Algorithm const& alg) {
|
||||
inline Algorithm RandomAlgorithmToAlgorithm(xla::RandomAlgorithm const& alg) {
|
||||
if (alg == xla::RandomAlgorithm::RNG_PHILOX) {
|
||||
return RNG_ALG_PHILOX;
|
||||
} else if (alg == xla::RandomAlgorithm::RNG_THREE_FRY) {
|
||||
return RNG_ALG_THREEFRY;
|
||||
} else if (alg == xla::RandomAlgorithm::RNG_DEFAULT) {
|
||||
return RNG_ALG_XLA_DEFAULT;
|
||||
}
|
||||
return RNG_ALG_THREEFRY;
|
||||
}
|
||||
@ -84,7 +92,7 @@ std::tuple<xla::XlaOp, xla::XlaOp, Algorithm> GetKeyCounterAlg(
|
||||
auto counter_shape =
|
||||
xla::ShapeUtil::MakeShape(xla::U64, {RNG_MAX_COUNTER_SIZE});
|
||||
auto counter = xla::Zeros(key.builder(), counter_shape);
|
||||
return std::make_tuple(key, counter, RNG_ALG_THREEFRY);
|
||||
return std::make_tuple(key, counter, RNG_ALG_XLA_DEFAULT);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -18,15 +18,30 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
enum Algorithm { RNG_ALG_PHILOX = 1, RNG_ALG_THREEFRY = 2 };
|
||||
enum Algorithm {
|
||||
// The Philox algorithm, as described in paper
|
||||
// ['Parallel Random Numbers: As Easy as 1, 2, 3']
|
||||
// (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf)
|
||||
RNG_ALG_PHILOX = 1,
|
||||
// The ThreeFry algorithm, as described in paper
|
||||
// ['Parallel Random Numbers: As Easy as 1, 2, 3']
|
||||
// (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf)
|
||||
RNG_ALG_THREEFRY = 2,
|
||||
// An algorithm suitable for TPU. Only available on XLA devices.
|
||||
RNG_ALG_XLA_DEFAULT = 3
|
||||
};
|
||||
|
||||
static constexpr int RNG_KEY_SIZE = 1;
|
||||
static constexpr int RNG_MAX_COUNTER_SIZE = 2;
|
||||
inline int GetCounterSize(Algorithm alg) {
|
||||
if (alg == RNG_ALG_PHILOX) {
|
||||
return 2;
|
||||
} else if (alg == RNG_ALG_THREEFRY) {
|
||||
return 1;
|
||||
} else if (alg == RNG_ALG_XLA_DEFAULT) {
|
||||
return 1;
|
||||
}
|
||||
return 1;
|
||||
return 2;
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user