[XLA] Shard PrngTest.ScalarBF16Tests, increase the timeout
The test occasaionally timed out. Fix this by: - parameterizing ScalarBF16Tests so that different test methods can be sharded. - increasing the number of shards on the test binary - increasing the timeout of the test binary PiperOrigin-RevId: 254224789
This commit is contained in:
parent
e346de4538
commit
68c6744dd4
@ -1547,7 +1547,9 @@ xla_test(
|
|||||||
|
|
||||||
xla_test(
|
xla_test(
|
||||||
name = "prng_test",
|
name = "prng_test",
|
||||||
|
timeout = "long",
|
||||||
srcs = ["prng_test.cc"],
|
srcs = ["prng_test.cc"],
|
||||||
|
shard_count = 6,
|
||||||
deps = [
|
deps = [
|
||||||
":test_macros_header",
|
":test_macros_header",
|
||||||
"//tensorflow/compiler/xla:literal",
|
"//tensorflow/compiler/xla:literal",
|
||||||
|
@ -81,28 +81,34 @@ XLA_TEST_F(PrngTest, TwelveValuesU524) { UniformTest<int32>(5, 24, {12}); }
|
|||||||
|
|
||||||
// TODO(b/71543667): Fix Rng ops on LLVM backends.
|
// TODO(b/71543667): Fix Rng ops on LLVM backends.
|
||||||
// TODO(b/122047800): Interpreter does not support BF16 for RNG ops.
|
// TODO(b/122047800): Interpreter does not support BF16 for RNG ops.
|
||||||
XLA_TEST_F(PrngTest, DISABLED_ON_INTERPRETER(
|
using ScalarBF16TestCase = std::tuple<int64, std::pair<float, float>>;
|
||||||
DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests)))) {
|
|
||||||
for (int64 seed = 0; seed < 100; ++seed) {
|
class ScalarBF16Test
|
||||||
|
: public PrngTest,
|
||||||
|
public ::testing::WithParamInterface<ScalarBF16TestCase> {};
|
||||||
|
|
||||||
|
XLA_TEST_P(ScalarBF16Test,
|
||||||
|
DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(DISABLED_ON_CPU(DoIt)))) {
|
||||||
|
auto test_params = GetParam();
|
||||||
|
UniformTest<bfloat16>(static_cast<bfloat16>(std::get<1>(test_params).first),
|
||||||
|
static_cast<bfloat16>(std::get<1>(test_params).second),
|
||||||
|
{},
|
||||||
|
/*seed=*/std::get<0>(test_params));
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
ScalarBF16TestInstance, ScalarBF16Test,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::Range<int64>(0, 100),
|
||||||
|
::testing::Values(
|
||||||
// The largest negative number smaller than zero in bf16 that's not
|
// The largest negative number smaller than zero in bf16 that's not
|
||||||
// denormalized.
|
// denormalized.
|
||||||
int32 low_raw = 0x80800000;
|
std::make_pair(static_cast<float>(-bfloat16::min_positive_normal()),
|
||||||
const float low = reinterpret_cast<const float&>(low_raw);
|
0.0f),
|
||||||
float high = 0.0f;
|
|
||||||
UniformTest<bfloat16>(static_cast<bfloat16>(low),
|
|
||||||
static_cast<bfloat16>(high), {}, /*seed=*/seed);
|
|
||||||
|
|
||||||
// Test odd and even values.
|
// Test odd and even values.
|
||||||
UniformTest<bfloat16>(static_cast<bfloat16>(32.75),
|
std::make_pair(32.75f, 33.00f), std::make_pair(32.50f, 32.75f),
|
||||||
static_cast<bfloat16>(33), {}, /*seed=*/seed);
|
std::make_pair(-33.00f, -32.75f),
|
||||||
UniformTest<bfloat16>(static_cast<bfloat16>(32.50),
|
std::make_pair(-32.75f, -32.50f))));
|
||||||
static_cast<bfloat16>(32.75), {}, /*seed=*/seed);
|
|
||||||
UniformTest<bfloat16>(static_cast<bfloat16>(-33.00),
|
|
||||||
static_cast<bfloat16>(-32.75), {}, /*seed=*/seed);
|
|
||||||
UniformTest<bfloat16>(static_cast<bfloat16>(-32.75),
|
|
||||||
static_cast<bfloat16>(-32.50), {}, /*seed=*/seed);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(b/71543667): Fix Rng ops on LLVM backends.
|
// TODO(b/71543667): Fix Rng ops on LLVM backends.
|
||||||
// TODO(b/122047800): Interpreter does not support BF16 for RNG ops.
|
// TODO(b/122047800): Interpreter does not support BF16 for RNG ops.
|
||||||
@ -220,7 +226,8 @@ XLA_TEST_F(PrngTest, MapUsingRng) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// This tests demonstrates the global seeding behavior.
|
// This tests demonstrates the global seeding behavior.
|
||||||
// * If a seed is passed in via Execute (ExecuteAndTransfer) then the output is
|
// * If a seed is passed in via Execute (ExecuteAndTransfer) then the output
|
||||||
|
// is
|
||||||
// fixed (i.e., there is a single output for a given seed);
|
// fixed (i.e., there is a single output for a given seed);
|
||||||
// * If no seed is passed in then the output of every call can be different;
|
// * If no seed is passed in then the output of every call can be different;
|
||||||
XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
|
XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
|
||||||
@ -280,8 +287,8 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
|
|||||||
EXPECT_FALSE(LiteralTestUtil::Equal(result5, result6));
|
EXPECT_FALSE(LiteralTestUtil::Equal(result5, result6));
|
||||||
}
|
}
|
||||||
|
|
||||||
// This test verifies that the two RNG instructions with the same parameters in
|
// This test verifies that the two RNG instructions with the same parameters
|
||||||
// the same HloComputation produces different values.
|
// in the same HloComputation produces different values.
|
||||||
XLA_TEST_F(PrngTest, DifferentValuesForIdenticalRngNodesInSameComputation) {
|
XLA_TEST_F(PrngTest, DifferentValuesForIdenticalRngNodesInSameComputation) {
|
||||||
// Build a U[0,1) computation.
|
// Build a U[0,1) computation.
|
||||||
auto build_computation = [this]() {
|
auto build_computation = [this]() {
|
||||||
|
Loading…
Reference in New Issue
Block a user