diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index ec52e3b51dc..ef56f4f0f66 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1547,7 +1547,9 @@ xla_test( xla_test( name = "prng_test", + timeout = "long", srcs = ["prng_test.cc"], + shard_count = 6, deps = [ ":test_macros_header", "//tensorflow/compiler/xla:literal", diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index e49bcf26bd6..f35b30ff90d 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -81,29 +81,35 @@ XLA_TEST_F(PrngTest, TwelveValuesU524) { UniformTest(5, 24, {12}); } // TODO(b/71543667): Fix Rng ops on LLVM backends. // TODO(b/122047800): Interpreter does not support BF16 for RNG ops. -XLA_TEST_F(PrngTest, DISABLED_ON_INTERPRETER( - DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests)))) { - for (int64 seed = 0; seed < 100; ++seed) { - // The largest negative number smaller than zero in bf16 that's not - // denormalized. - int32 low_raw = 0x80800000; - const float low = reinterpret_cast(low_raw); - float high = 0.0f; - UniformTest(static_cast(low), - static_cast(high), {}, /*seed=*/seed); +using ScalarBF16TestCase = std::tuple>; - // Test odd and even values. - UniformTest(static_cast(32.75), - static_cast(33), {}, /*seed=*/seed); - UniformTest(static_cast(32.50), - static_cast(32.75), {}, /*seed=*/seed); - UniformTest(static_cast(-33.00), - static_cast(-32.75), {}, /*seed=*/seed); - UniformTest(static_cast(-32.75), - static_cast(-32.50), {}, /*seed=*/seed); - } +class ScalarBF16Test + : public PrngTest, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(ScalarBF16Test, + DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(DISABLED_ON_CPU(DoIt)))) { + auto test_params = GetParam(); + UniformTest(static_cast(std::get<1>(test_params).first), + static_cast(std::get<1>(test_params).second), + {}, + /*seed=*/std::get<0>(test_params)); } +INSTANTIATE_TEST_SUITE_P( + ScalarBF16TestInstance, ScalarBF16Test, + ::testing::Combine( + ::testing::Range(0, 100), + ::testing::Values( + // The largest negative number smaller than zero in bf16 that's not + // denormalized. + std::make_pair(static_cast(-bfloat16::min_positive_normal()), + 0.0f), + // Test odd and even values. + std::make_pair(32.75f, 33.00f), std::make_pair(32.50f, 32.75f), + std::make_pair(-33.00f, -32.75f), + std::make_pair(-32.75f, -32.50f)))); + // TODO(b/71543667): Fix Rng ops on LLVM backends. // TODO(b/122047800): Interpreter does not support BF16 for RNG ops. XLA_TEST_F(PrngTest, DISABLED_ON_INTERPRETER(DISABLED_ON_GPU( @@ -220,7 +226,8 @@ XLA_TEST_F(PrngTest, MapUsingRng) { } // This tests demonstrates the global seeding behavior. -// * If a seed is passed in via Execute (ExecuteAndTransfer) then the output is +// * If a seed is passed in via Execute (ExecuteAndTransfer) then the output +// is // fixed (i.e., there is a single output for a given seed); // * If no seed is passed in then the output of every call can be different; XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { @@ -280,8 +287,8 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { EXPECT_FALSE(LiteralTestUtil::Equal(result5, result6)); } -// This test verifies that the two RNG instructions with the same parameters in -// the same HloComputation produces different values. +// This test verifies that the two RNG instructions with the same parameters +// in the same HloComputation produces different values. XLA_TEST_F(PrngTest, DifferentValuesForIdenticalRngNodesInSameComputation) { // Build a U[0,1) computation. auto build_computation = [this]() {