Changes StatelessRandomGetKeyCounterAlg to pick RNG_DEFAULT on TPU instead of RNG_THREE_FRY. The latter causes an OOM for StatelessRandomNormalV2 because of excessive TPU padding. The old V1 ops also pick RNG_DEFAULT.

PiperOrigin-RevId: 341109746
Change-Id: Ia325ed8a40b6e5f9cb00d1cfdbb1ada697c5db1f
This commit is contained in:
Peng Wang 2020-11-06 13:35:47 -08:00 committed by TensorFlower Gardener
parent 498c9045ff
commit 70d22236db
3 changed files with 39 additions and 3 deletions

View File

@ -29,6 +29,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.kernel_tests.random import util as \
random_test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import stateless_random_ops as stateless
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -63,6 +64,18 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
lambda x: stateless.stateless_random_normal([], seed=x), [x])
f([1, 2])
def testLargeNormal(self):
"""Tests an OOM bug of StatelessRandomNormalV2 on TPU."""
with self.session() as sess, self.test_scope():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
key, counter, alg = (gen_stateless_random_ops_v2.
stateless_random_get_key_counter_alg(seed_t))
x = gen_stateless_random_ops_v2.stateless_random_normal_v2(
shape=[1024, 32000], key=key, counter=counter, dtype=dtypes.float32,
alg=alg)
y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
self.assertAllEqual([1024, 32000], y.shape)
def testDeterminism(self):
# Stateless values should be equal iff the seeds are equal (roughly)
with self.session(), self.test_scope():

View File

@ -42,6 +42,10 @@ namespace {
inline xla::RandomAlgorithm AlgorithmToRandomAlgorithm(Algorithm const& alg) {
if (alg == RNG_ALG_PHILOX) {
return xla::RandomAlgorithm::RNG_PHILOX;
} else if (alg == RNG_ALG_THREEFRY) {
return xla::RandomAlgorithm::RNG_THREE_FRY;
} else if (alg == RNG_ALG_XLA_DEFAULT) {
return xla::RandomAlgorithm::RNG_DEFAULT;
}
return xla::RandomAlgorithm::RNG_THREE_FRY;
}
@ -49,6 +53,10 @@ inline xla::RandomAlgorithm AlgorithmToRandomAlgorithm(Algorithm const& alg) {
inline Algorithm RandomAlgorithmToAlgorithm(xla::RandomAlgorithm const& alg) {
if (alg == xla::RandomAlgorithm::RNG_PHILOX) {
return RNG_ALG_PHILOX;
} else if (alg == xla::RandomAlgorithm::RNG_THREE_FRY) {
return RNG_ALG_THREEFRY;
} else if (alg == xla::RandomAlgorithm::RNG_DEFAULT) {
return RNG_ALG_XLA_DEFAULT;
}
return RNG_ALG_THREEFRY;
}
@ -84,7 +92,7 @@ std::tuple<xla::XlaOp, xla::XlaOp, Algorithm> GetKeyCounterAlg(
auto counter_shape =
xla::ShapeUtil::MakeShape(xla::U64, {RNG_MAX_COUNTER_SIZE});
auto counter = xla::Zeros(key.builder(), counter_shape);
return std::make_tuple(key, counter, RNG_ALG_THREEFRY);
return std::make_tuple(key, counter, RNG_ALG_XLA_DEFAULT);
}
}

View File

@ -18,15 +18,30 @@ limitations under the License.
namespace tensorflow {
enum Algorithm { RNG_ALG_PHILOX = 1, RNG_ALG_THREEFRY = 2 };
enum Algorithm {
// The Philox algorithm, as described in paper
// ['Parallel Random Numbers: As Easy as 1, 2, 3']
// (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf)
RNG_ALG_PHILOX = 1,
// The ThreeFry algorithm, as described in paper
// ['Parallel Random Numbers: As Easy as 1, 2, 3']
// (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf)
RNG_ALG_THREEFRY = 2,
// An algorithm suitable for TPU. Only available on XLA devices.
RNG_ALG_XLA_DEFAULT = 3
};
static constexpr int RNG_KEY_SIZE = 1;
static constexpr int RNG_MAX_COUNTER_SIZE = 2;
inline int GetCounterSize(Algorithm alg) {
if (alg == RNG_ALG_PHILOX) {
return 2;
} else if (alg == RNG_ALG_THREEFRY) {
return 1;
} else if (alg == RNG_ALG_XLA_DEFAULT) {
return 1;
}
return 1;
return 2;
}
} // end namespace tensorflow