Adds R1 test for ReduceWindow.

PiperOrigin-RevId: 183345779
This commit is contained in:
Tayo Oguntebi 2018-01-25 23:37:20 -08:00 committed by TensorFlower Gardener
parent e4912296bc
commit ffda6079ed

View File

@ -1016,9 +1016,8 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
::testing::tuple<R2ReduceWindowTestData, bool>> { ::testing::tuple<R2ReduceWindowTestData, bool>> {
protected: protected:
R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
};
TEST_P(R2ReduceWindowTest, Add) { void DoIt() {
ComputationBuilder b(client_, TestName()); ComputationBuilder b(client_, TestName());
const auto& param = ::testing::get<0>(GetParam()); const auto& param = ::testing::get<0>(GetParam());
CHECK(param.reducer == kAdd); CHECK(param.reducer == kAdd);
@ -1047,6 +1046,9 @@ TEST_P(R2ReduceWindowTest, Add) {
ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected),
{input_arg.get()}, DefaultErrorSpec()); {input_arg.get()}, DefaultErrorSpec());
} }
};
TEST_P(R2ReduceWindowTest, DoIt) { DoIt(); }
INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_CASE_P(
R2ReduceWindowTestInstantiation, R2ReduceWindowTest, R2ReduceWindowTestInstantiation, R2ReduceWindowTest,
@ -1054,6 +1056,26 @@ INSTANTIATE_TEST_CASE_P(
::testing::ValuesIn(use_bfloat16_params)), ::testing::ValuesIn(use_bfloat16_params)),
R2ReduceWindowTestDataToString); R2ReduceWindowTestDataToString);
class R2ReduceWindowFailingCpuGpuBf16Test : public R2ReduceWindowTest {};
// TODO(b/72234705): Fix the test cases failed on CPU and GPU.
XLA_TEST_P(R2ReduceWindowFailingCpuGpuBf16Test,
DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt)))) {
DoIt();
}
const R2ReduceWindowTestData kR2FailingValuesCpuGpuBf16Test[] = {
{/*base_bounds=*/{8, 128}, /*window_bounds=*/{8, 128},
/*strides=*/{1, 1}, /*layout=*/{1, 0},
/*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
};
INSTANTIATE_TEST_CASE_P(
R2ReduceWindowFailingInstantiation, R2ReduceWindowFailingCpuGpuBf16Test,
::testing::Combine(::testing::ValuesIn(kR2FailingValuesCpuGpuBf16Test),
::testing::ValuesIn(use_bfloat16_params)),
R2ReduceWindowTestDataToString);
struct R1ReduceWindowTestData { struct R1ReduceWindowTestData {
int64 base_bounds[1]; int64 base_bounds[1];
int64 window_bounds[1]; int64 window_bounds[1];