[XLA] Shard PrngTest.ScalarBF16Tests, increase the timeout

The test occasionally timed out. Fix this by:
- parameterizing ScalarBF16Tests so that different test methods can be sharded.
- increasing the number of shards on the test binary
- increasing the timeout of the test binary

PiperOrigin-RevId: 254224789
This commit is contained in:
David Majnemer 2019-06-20 10:33:25 -07:00 committed by TensorFlower Gardener
parent e346de4538
commit 68c6744dd4
2 changed files with 32 additions and 23 deletions

View File

@ -1547,7 +1547,9 @@ xla_test(
xla_test(
name = "prng_test",
timeout = "long",
srcs = ["prng_test.cc"],
shard_count = 6,
deps = [
":test_macros_header",
"//tensorflow/compiler/xla:literal",

View File

@ -81,29 +81,35 @@ XLA_TEST_F(PrngTest, TwelveValuesU524) { UniformTest<int32>(5, 24, {12}); }
// TODO(b/71543667): Fix Rng ops on LLVM backends.
// TODO(b/122047800): Interpreter does not support BF16 for RNG ops.
XLA_TEST_F(PrngTest, DISABLED_ON_INTERPRETER(
DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests)))) {
for (int64 seed = 0; seed < 100; ++seed) {
// The largest negative number smaller than zero in bf16 that's not
// denormalized.
int32 low_raw = 0x80800000;
const float low = reinterpret_cast<const float&>(low_raw);
float high = 0.0f;
UniformTest<bfloat16>(static_cast<bfloat16>(low),
static_cast<bfloat16>(high), {}, /*seed=*/seed);
using ScalarBF16TestCase = std::tuple<int64, std::pair<float, float>>;
// Test odd and even values.
UniformTest<bfloat16>(static_cast<bfloat16>(32.75),
static_cast<bfloat16>(33), {}, /*seed=*/seed);
UniformTest<bfloat16>(static_cast<bfloat16>(32.50),
static_cast<bfloat16>(32.75), {}, /*seed=*/seed);
UniformTest<bfloat16>(static_cast<bfloat16>(-33.00),
static_cast<bfloat16>(-32.75), {}, /*seed=*/seed);
UniformTest<bfloat16>(static_cast<bfloat16>(-32.75),
static_cast<bfloat16>(-32.50), {}, /*seed=*/seed);
}
class ScalarBF16Test
: public PrngTest,
public ::testing::WithParamInterface<ScalarBF16TestCase> {};
XLA_TEST_P(ScalarBF16Test,
DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(DISABLED_ON_CPU(DoIt)))) {
auto test_params = GetParam();
UniformTest<bfloat16>(static_cast<bfloat16>(std::get<1>(test_params).first),
static_cast<bfloat16>(std::get<1>(test_params).second),
{},
/*seed=*/std::get<0>(test_params));
}
INSTANTIATE_TEST_SUITE_P(
ScalarBF16TestInstance, ScalarBF16Test,
::testing::Combine(
::testing::Range<int64>(0, 100),
::testing::Values(
// The largest negative number smaller than zero in bf16 that's not
// denormalized.
std::make_pair(static_cast<float>(-bfloat16::min_positive_normal()),
0.0f),
// Test odd and even values.
std::make_pair(32.75f, 33.00f), std::make_pair(32.50f, 32.75f),
std::make_pair(-33.00f, -32.75f),
std::make_pair(-32.75f, -32.50f))));
// TODO(b/71543667): Fix Rng ops on LLVM backends.
// TODO(b/122047800): Interpreter does not support BF16 for RNG ops.
XLA_TEST_F(PrngTest, DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(
@ -220,7 +226,8 @@ XLA_TEST_F(PrngTest, MapUsingRng) {
}
// This tests demonstrates the global seeding behavior.
// * If a seed is passed in via Execute (ExecuteAndTransfer) then the output is
// * If a seed is passed in via Execute (ExecuteAndTransfer) then the output
// is
// fixed (i.e., there is a single output for a given seed);
// * If no seed is passed in then the output of every call can be different;
XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
@ -280,8 +287,8 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
EXPECT_FALSE(LiteralTestUtil::Equal(result5, result6));
}
// This test verifies that the two RNG instructions with the same parameters in
// the same HloComputation produces different values.
// This test verifies that the two RNG instructions with the same parameters
// in the same HloComputation produces different values.
XLA_TEST_F(PrngTest, DifferentValuesForIdenticalRngNodesInSameComputation) {
// Build a U[0,1) computation.
auto build_computation = [this]() {