[XLA] Shard PrngTest.ScalarBF16Tests, increase the timeout
The test occasaionally timed out. Fix this by: - parameterizing ScalarBF16Tests so that different test methods can be sharded. - increasing the number of shards on the test binary - increasing the timeout of the test binary PiperOrigin-RevId: 254224789
This commit is contained in:
parent
e346de4538
commit
68c6744dd4
@ -1547,7 +1547,9 @@ xla_test(
|
||||
|
||||
xla_test(
|
||||
name = "prng_test",
|
||||
timeout = "long",
|
||||
srcs = ["prng_test.cc"],
|
||||
shard_count = 6,
|
||||
deps = [
|
||||
":test_macros_header",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
|
@ -81,29 +81,35 @@ XLA_TEST_F(PrngTest, TwelveValuesU524) { UniformTest<int32>(5, 24, {12}); }
|
||||
|
||||
// TODO(b/71543667): Fix Rng ops on LLVM backends.
|
||||
// TODO(b/122047800): Interpreter does not support BF16 for RNG ops.
|
||||
XLA_TEST_F(PrngTest, DISABLED_ON_INTERPRETER(
|
||||
DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests)))) {
|
||||
for (int64 seed = 0; seed < 100; ++seed) {
|
||||
// The largest negative number smaller than zero in bf16 that's not
|
||||
// denormalized.
|
||||
int32 low_raw = 0x80800000;
|
||||
const float low = reinterpret_cast<const float&>(low_raw);
|
||||
float high = 0.0f;
|
||||
UniformTest<bfloat16>(static_cast<bfloat16>(low),
|
||||
static_cast<bfloat16>(high), {}, /*seed=*/seed);
|
||||
using ScalarBF16TestCase = std::tuple<int64, std::pair<float, float>>;
|
||||
|
||||
// Test odd and even values.
|
||||
UniformTest<bfloat16>(static_cast<bfloat16>(32.75),
|
||||
static_cast<bfloat16>(33), {}, /*seed=*/seed);
|
||||
UniformTest<bfloat16>(static_cast<bfloat16>(32.50),
|
||||
static_cast<bfloat16>(32.75), {}, /*seed=*/seed);
|
||||
UniformTest<bfloat16>(static_cast<bfloat16>(-33.00),
|
||||
static_cast<bfloat16>(-32.75), {}, /*seed=*/seed);
|
||||
UniformTest<bfloat16>(static_cast<bfloat16>(-32.75),
|
||||
static_cast<bfloat16>(-32.50), {}, /*seed=*/seed);
|
||||
}
|
||||
class ScalarBF16Test
|
||||
: public PrngTest,
|
||||
public ::testing::WithParamInterface<ScalarBF16TestCase> {};
|
||||
|
||||
XLA_TEST_P(ScalarBF16Test,
|
||||
DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(DISABLED_ON_CPU(DoIt)))) {
|
||||
auto test_params = GetParam();
|
||||
UniformTest<bfloat16>(static_cast<bfloat16>(std::get<1>(test_params).first),
|
||||
static_cast<bfloat16>(std::get<1>(test_params).second),
|
||||
{},
|
||||
/*seed=*/std::get<0>(test_params));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
ScalarBF16TestInstance, ScalarBF16Test,
|
||||
::testing::Combine(
|
||||
::testing::Range<int64>(0, 100),
|
||||
::testing::Values(
|
||||
// The largest negative number smaller than zero in bf16 that's not
|
||||
// denormalized.
|
||||
std::make_pair(static_cast<float>(-bfloat16::min_positive_normal()),
|
||||
0.0f),
|
||||
// Test odd and even values.
|
||||
std::make_pair(32.75f, 33.00f), std::make_pair(32.50f, 32.75f),
|
||||
std::make_pair(-33.00f, -32.75f),
|
||||
std::make_pair(-32.75f, -32.50f))));
|
||||
|
||||
// TODO(b/71543667): Fix Rng ops on LLVM backends.
|
||||
// TODO(b/122047800): Interpreter does not support BF16 for RNG ops.
|
||||
XLA_TEST_F(PrngTest, DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(
|
||||
@ -220,7 +226,8 @@ XLA_TEST_F(PrngTest, MapUsingRng) {
|
||||
}
|
||||
|
||||
// This tests demonstrates the global seeding behavior.
|
||||
// * If a seed is passed in via Execute (ExecuteAndTransfer) then the output is
|
||||
// * If a seed is passed in via Execute (ExecuteAndTransfer) then the output
|
||||
// is
|
||||
// fixed (i.e., there is a single output for a given seed);
|
||||
// * If no seed is passed in then the output of every call can be different;
|
||||
XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
|
||||
@ -280,8 +287,8 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
|
||||
EXPECT_FALSE(LiteralTestUtil::Equal(result5, result6));
|
||||
}
|
||||
|
||||
// This test verifies that the two RNG instructions with the same parameters in
|
||||
// the same HloComputation produces different values.
|
||||
// This test verifies that the two RNG instructions with the same parameters
|
||||
// in the same HloComputation produces different values.
|
||||
XLA_TEST_F(PrngTest, DifferentValuesForIdenticalRngNodesInSameComputation) {
|
||||
// Build a U[0,1) computation.
|
||||
auto build_computation = [this]() {
|
||||
|
Loading…
Reference in New Issue
Block a user