Changes StatelessRandomGetKeyCounterAlg to pick RNG_DEFAULT on TPU instead of RNG_THREE_FRY. The latter causes an OOM for StatelessRandomNormalV2 because of excessive TPU padding. The old V1 ops also pick RNG_DEFAULT.
PiperOrigin-RevId: 341109746 Change-Id: Ia325ed8a40b6e5f9cb00d1cfdbb1ada697c5db1f
This commit is contained in:
parent
498c9045ff
commit
70d22236db
@ -29,6 +29,7 @@ from tensorflow.python.framework import test_util
|
|||||||
from tensorflow.python.kernel_tests.random import util as \
|
from tensorflow.python.kernel_tests.random import util as \
|
||||||
random_test_util
|
random_test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import gen_stateless_random_ops_v2
|
||||||
from tensorflow.python.ops import stateless_random_ops as stateless
|
from tensorflow.python.ops import stateless_random_ops as stateless
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -63,6 +64,18 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
|
|||||||
lambda x: stateless.stateless_random_normal([], seed=x), [x])
|
lambda x: stateless.stateless_random_normal([], seed=x), [x])
|
||||||
f([1, 2])
|
f([1, 2])
|
||||||
|
|
||||||
|
def testLargeNormal(self):
|
||||||
|
"""Tests an OOM bug of StatelessRandomNormalV2 on TPU."""
|
||||||
|
with self.session() as sess, self.test_scope():
|
||||||
|
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
|
||||||
|
key, counter, alg = (gen_stateless_random_ops_v2.
|
||||||
|
stateless_random_get_key_counter_alg(seed_t))
|
||||||
|
x = gen_stateless_random_ops_v2.stateless_random_normal_v2(
|
||||||
|
shape=[1024, 32000], key=key, counter=counter, dtype=dtypes.float32,
|
||||||
|
alg=alg)
|
||||||
|
y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
|
||||||
|
self.assertAllEqual([1024, 32000], y.shape)
|
||||||
|
|
||||||
def testDeterminism(self):
|
def testDeterminism(self):
|
||||||
# Stateless values should be equal iff the seeds are equal (roughly)
|
# Stateless values should be equal iff the seeds are equal (roughly)
|
||||||
with self.session(), self.test_scope():
|
with self.session(), self.test_scope():
|
||||||
|
@ -42,6 +42,10 @@ namespace {
|
|||||||
inline xla::RandomAlgorithm AlgorithmToRandomAlgorithm(Algorithm const& alg) {
|
inline xla::RandomAlgorithm AlgorithmToRandomAlgorithm(Algorithm const& alg) {
|
||||||
if (alg == RNG_ALG_PHILOX) {
|
if (alg == RNG_ALG_PHILOX) {
|
||||||
return xla::RandomAlgorithm::RNG_PHILOX;
|
return xla::RandomAlgorithm::RNG_PHILOX;
|
||||||
|
} else if (alg == RNG_ALG_THREEFRY) {
|
||||||
|
return xla::RandomAlgorithm::RNG_THREE_FRY;
|
||||||
|
} else if (alg == RNG_ALG_XLA_DEFAULT) {
|
||||||
|
return xla::RandomAlgorithm::RNG_DEFAULT;
|
||||||
}
|
}
|
||||||
return xla::RandomAlgorithm::RNG_THREE_FRY;
|
return xla::RandomAlgorithm::RNG_THREE_FRY;
|
||||||
}
|
}
|
||||||
@ -49,6 +53,10 @@ inline xla::RandomAlgorithm AlgorithmToRandomAlgorithm(Algorithm const& alg) {
|
|||||||
inline Algorithm RandomAlgorithmToAlgorithm(xla::RandomAlgorithm const& alg) {
|
inline Algorithm RandomAlgorithmToAlgorithm(xla::RandomAlgorithm const& alg) {
|
||||||
if (alg == xla::RandomAlgorithm::RNG_PHILOX) {
|
if (alg == xla::RandomAlgorithm::RNG_PHILOX) {
|
||||||
return RNG_ALG_PHILOX;
|
return RNG_ALG_PHILOX;
|
||||||
|
} else if (alg == xla::RandomAlgorithm::RNG_THREE_FRY) {
|
||||||
|
return RNG_ALG_THREEFRY;
|
||||||
|
} else if (alg == xla::RandomAlgorithm::RNG_DEFAULT) {
|
||||||
|
return RNG_ALG_XLA_DEFAULT;
|
||||||
}
|
}
|
||||||
return RNG_ALG_THREEFRY;
|
return RNG_ALG_THREEFRY;
|
||||||
}
|
}
|
||||||
@ -84,7 +92,7 @@ std::tuple<xla::XlaOp, xla::XlaOp, Algorithm> GetKeyCounterAlg(
|
|||||||
auto counter_shape =
|
auto counter_shape =
|
||||||
xla::ShapeUtil::MakeShape(xla::U64, {RNG_MAX_COUNTER_SIZE});
|
xla::ShapeUtil::MakeShape(xla::U64, {RNG_MAX_COUNTER_SIZE});
|
||||||
auto counter = xla::Zeros(key.builder(), counter_shape);
|
auto counter = xla::Zeros(key.builder(), counter_shape);
|
||||||
return std::make_tuple(key, counter, RNG_ALG_THREEFRY);
|
return std::make_tuple(key, counter, RNG_ALG_XLA_DEFAULT);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -18,15 +18,30 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
enum Algorithm { RNG_ALG_PHILOX = 1, RNG_ALG_THREEFRY = 2 };
|
enum Algorithm {
|
||||||
|
// The Philox algorithm, as described in paper
|
||||||
|
// ['Parallel Random Numbers: As Easy as 1, 2, 3']
|
||||||
|
// (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf)
|
||||||
|
RNG_ALG_PHILOX = 1,
|
||||||
|
// The ThreeFry algorithm, as described in paper
|
||||||
|
// ['Parallel Random Numbers: As Easy as 1, 2, 3']
|
||||||
|
// (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf)
|
||||||
|
RNG_ALG_THREEFRY = 2,
|
||||||
|
// An algorithm suitable for TPU. Only available on XLA devices.
|
||||||
|
RNG_ALG_XLA_DEFAULT = 3
|
||||||
|
};
|
||||||
|
|
||||||
static constexpr int RNG_KEY_SIZE = 1;
|
static constexpr int RNG_KEY_SIZE = 1;
|
||||||
static constexpr int RNG_MAX_COUNTER_SIZE = 2;
|
static constexpr int RNG_MAX_COUNTER_SIZE = 2;
|
||||||
inline int GetCounterSize(Algorithm alg) {
|
inline int GetCounterSize(Algorithm alg) {
|
||||||
if (alg == RNG_ALG_PHILOX) {
|
if (alg == RNG_ALG_PHILOX) {
|
||||||
return 2;
|
return 2;
|
||||||
}
|
} else if (alg == RNG_ALG_THREEFRY) {
|
||||||
return 1;
|
return 1;
|
||||||
|
} else if (alg == RNG_ALG_XLA_DEFAULT) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
return 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
Loading…
x
Reference in New Issue
Block a user