ReduceWindow test cleanup:

- Enable Add reduction on R4 tensors.
    - Remove redundant tests, to keep runtime reasonable.

PiperOrigin-RevId: 237310974
This commit is contained in:
Tayo Oguntebi 2019-03-07 13:24:44 -08:00 committed by TensorFlower Gardener
parent 7dd20b844c
commit 17baa62ff4

View File

@ -611,6 +611,12 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
// values. (Technically, the requirement is that the iota length is // values. (Technically, the requirement is that the iota length is
// relatively prime to all of the dimensions involved in the reduce-window.) // relatively prime to all of the dimensions involved in the reduce-window.)
input.FillRepeatedIota(0, 137); input.FillRepeatedIota(0, 137);
// Floating point sum reduction requires higher localized precision. We need
// the following normalization in order to enable testing of kAdd on large
// windows.
input.Each([&](absl::Span<const int64> /*indices*/, float* value) {
*value = *value / 10000000000.f;
});
Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout(param.layout)); input, LayoutUtil::MakeLayout(param.layout));
XlaOp parameter; XlaOp parameter;
@ -626,12 +632,6 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
CHECK(param.reducer == kAdd || param.reducer == kMax); CHECK(param.reducer == kAdd || param.reducer == kMax);
auto reducer = param.reducer; auto reducer = param.reducer;
if (use_bfloat16()) {
// To avoid numerical issues, force the reducer to be kMax for bf16
// inputs.
reducer = kMax;
}
auto computation = reducer == kAdd auto computation = reducer == kAdd
? CreateScalarAddComputation(FloatType(), &b) ? CreateScalarAddComputation(FloatType(), &b)
: CreateScalarMaxComputation(FloatType(), &b); : CreateScalarMaxComputation(FloatType(), &b);
@ -697,15 +697,6 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*layout=*/{3, 2, 1, 0}, /*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd}, /*reducer=*/kAdd},
// With non-1x1 window.
R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
/*window_bounds=*/{2, 3, 1, 1},
/*strides=*/{1, 1, 1, 1},
/*pad_low=*/{0, 0, 0, 0},
/*pad_high=*/{0, 0, 0, 0},
/*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd},
// With max instead of add. // With max instead of add.
R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
/*window_bounds=*/{2, 3, 1, 1}, /*window_bounds=*/{2, 3, 1, 1},
@ -778,15 +769,6 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*layout=*/{3, 2, 1, 0}, /*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd}, /*reducer=*/kAdd},
// With second minor dimension == 9.
R4ReduceWindowTestData{/*base_bounds=*/{2, 3, 9, 127},
/*window_bounds=*/{1, 1, 1, 1},
/*strides=*/{1, 1, 1, 1},
/*pad_low=*/{0, 0, 0, 0},
/*pad_high=*/{0, 0, 0, 0},
/*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd},
// With minor dimension == 129. // With minor dimension == 129.
R4ReduceWindowTestData{/*base_bounds=*/{3, 2, 7, 129}, R4ReduceWindowTestData{/*base_bounds=*/{3, 2, 7, 129},
/*window_bounds=*/{1, 1, 1, 1}, /*window_bounds=*/{1, 1, 1, 1},
@ -814,7 +796,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*layout=*/{3, 2, 1, 0}, /*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd}, /*reducer=*/kAdd},
R4ReduceWindowTestData{/*base_bounds=*/{8, 256, 256, 3}, R4ReduceWindowTestData{/*base_bounds=*/{8, 100, 100, 3},
/*window_bounds=*/{1, 64, 64, 1}, /*window_bounds=*/{1, 64, 64, 1},
/*strides=*/{1, 64, 64, 1}, /*strides=*/{1, 64, 64, 1},
/*pad_low=*/{0, 0, 0, 0}, /*pad_low=*/{0, 0, 0, 0},
@ -828,7 +810,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*pad_low=*/{0, 0, 0, 0}, /*pad_low=*/{0, 0, 0, 0},
/*pad_high=*/{0, 0, 0, 0}, /*pad_high=*/{0, 0, 0, 0},
/*layout=*/{3, 2, 1, 0}, /*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd}, /*reducer=*/kMax},
R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
/*window_bounds=*/{2, 3, 4, 5}, /*window_bounds=*/{2, 3, 4, 5},
@ -848,7 +830,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*reducer=*/kAdd}, /*reducer=*/kAdd},
// With 0123 layout. // With 0123 layout.
R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 23}, R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 13, 17},
/*window_bounds=*/{2, 3, 7, 9}, /*window_bounds=*/{2, 3, 7, 9},
/*strides=*/{1, 2, 5, 8}, /*strides=*/{1, 2, 5, 8},
/*pad_low=*/{0, 0, 0, 0}, /*pad_low=*/{0, 0, 0, 0},
@ -900,7 +882,6 @@ INSTANTIATE_TEST_CASE_P(
::testing::ValuesIn(use_bfloat16_params)), ::testing::ValuesIn(use_bfloat16_params)),
R4ReduceWindowTestDataToString); R4ReduceWindowTestDataToString);
struct R3ReduceWindowTestData { struct R3ReduceWindowTestData {
int64 base_bounds[3]; int64 base_bounds[3];
int64 window_bounds[3]; int64 window_bounds[3];