ReduceWindow test cleanup:
- Enable Add reduction on R4 tensors. - Remove redundant tests, to keep runtime reasonable. PiperOrigin-RevId: 237310974
This commit is contained in:
parent
7dd20b844c
commit
17baa62ff4
@ -611,6 +611,12 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
|
|||||||
// values. (Technically, the requirement is that the iota length is
|
// values. (Technically, the requirement is that the iota length is
|
||||||
// relatively prime to all of the dimensions involved in the reduce-window.)
|
// relatively prime to all of the dimensions involved in the reduce-window.)
|
||||||
input.FillRepeatedIota(0, 137);
|
input.FillRepeatedIota(0, 137);
|
||||||
|
// Floating point sum reduction requires higher localized precision. We need
|
||||||
|
// the following normalization in order to enable testing of kAdd on large
|
||||||
|
// windows.
|
||||||
|
input.Each([&](absl::Span<const int64> /*indices*/, float* value) {
|
||||||
|
*value = *value / 10000000000.f;
|
||||||
|
});
|
||||||
Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
|
Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
|
||||||
input, LayoutUtil::MakeLayout(param.layout));
|
input, LayoutUtil::MakeLayout(param.layout));
|
||||||
XlaOp parameter;
|
XlaOp parameter;
|
||||||
@ -626,12 +632,6 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
|
|||||||
CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
|
CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
|
||||||
CHECK(param.reducer == kAdd || param.reducer == kMax);
|
CHECK(param.reducer == kAdd || param.reducer == kMax);
|
||||||
auto reducer = param.reducer;
|
auto reducer = param.reducer;
|
||||||
if (use_bfloat16()) {
|
|
||||||
// To avoid numerical issues, force the reducer to be kMax for bf16
|
|
||||||
// inputs.
|
|
||||||
reducer = kMax;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto computation = reducer == kAdd
|
auto computation = reducer == kAdd
|
||||||
? CreateScalarAddComputation(FloatType(), &b)
|
? CreateScalarAddComputation(FloatType(), &b)
|
||||||
: CreateScalarMaxComputation(FloatType(), &b);
|
: CreateScalarMaxComputation(FloatType(), &b);
|
||||||
@ -697,15 +697,6 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
|
|||||||
/*layout=*/{3, 2, 1, 0},
|
/*layout=*/{3, 2, 1, 0},
|
||||||
/*reducer=*/kAdd},
|
/*reducer=*/kAdd},
|
||||||
|
|
||||||
// With non-1x1 window.
|
|
||||||
R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
|
|
||||||
/*window_bounds=*/{2, 3, 1, 1},
|
|
||||||
/*strides=*/{1, 1, 1, 1},
|
|
||||||
/*pad_low=*/{0, 0, 0, 0},
|
|
||||||
/*pad_high=*/{0, 0, 0, 0},
|
|
||||||
/*layout=*/{3, 2, 1, 0},
|
|
||||||
/*reducer=*/kAdd},
|
|
||||||
|
|
||||||
// With max instead of add.
|
// With max instead of add.
|
||||||
R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
|
R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
|
||||||
/*window_bounds=*/{2, 3, 1, 1},
|
/*window_bounds=*/{2, 3, 1, 1},
|
||||||
@ -778,15 +769,6 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
|
|||||||
/*layout=*/{3, 2, 1, 0},
|
/*layout=*/{3, 2, 1, 0},
|
||||||
/*reducer=*/kAdd},
|
/*reducer=*/kAdd},
|
||||||
|
|
||||||
// With second minor dimension == 9.
|
|
||||||
R4ReduceWindowTestData{/*base_bounds=*/{2, 3, 9, 127},
|
|
||||||
/*window_bounds=*/{1, 1, 1, 1},
|
|
||||||
/*strides=*/{1, 1, 1, 1},
|
|
||||||
/*pad_low=*/{0, 0, 0, 0},
|
|
||||||
/*pad_high=*/{0, 0, 0, 0},
|
|
||||||
/*layout=*/{3, 2, 1, 0},
|
|
||||||
/*reducer=*/kAdd},
|
|
||||||
|
|
||||||
// With minor dimension == 129.
|
// With minor dimension == 129.
|
||||||
R4ReduceWindowTestData{/*base_bounds=*/{3, 2, 7, 129},
|
R4ReduceWindowTestData{/*base_bounds=*/{3, 2, 7, 129},
|
||||||
/*window_bounds=*/{1, 1, 1, 1},
|
/*window_bounds=*/{1, 1, 1, 1},
|
||||||
@ -814,7 +796,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
|
|||||||
/*layout=*/{3, 2, 1, 0},
|
/*layout=*/{3, 2, 1, 0},
|
||||||
/*reducer=*/kAdd},
|
/*reducer=*/kAdd},
|
||||||
|
|
||||||
R4ReduceWindowTestData{/*base_bounds=*/{8, 256, 256, 3},
|
R4ReduceWindowTestData{/*base_bounds=*/{8, 100, 100, 3},
|
||||||
/*window_bounds=*/{1, 64, 64, 1},
|
/*window_bounds=*/{1, 64, 64, 1},
|
||||||
/*strides=*/{1, 64, 64, 1},
|
/*strides=*/{1, 64, 64, 1},
|
||||||
/*pad_low=*/{0, 0, 0, 0},
|
/*pad_low=*/{0, 0, 0, 0},
|
||||||
@ -828,7 +810,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
|
|||||||
/*pad_low=*/{0, 0, 0, 0},
|
/*pad_low=*/{0, 0, 0, 0},
|
||||||
/*pad_high=*/{0, 0, 0, 0},
|
/*pad_high=*/{0, 0, 0, 0},
|
||||||
/*layout=*/{3, 2, 1, 0},
|
/*layout=*/{3, 2, 1, 0},
|
||||||
/*reducer=*/kAdd},
|
/*reducer=*/kMax},
|
||||||
|
|
||||||
R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
|
R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
|
||||||
/*window_bounds=*/{2, 3, 4, 5},
|
/*window_bounds=*/{2, 3, 4, 5},
|
||||||
@ -848,7 +830,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
|
|||||||
/*reducer=*/kAdd},
|
/*reducer=*/kAdd},
|
||||||
|
|
||||||
// With 0123 layout.
|
// With 0123 layout.
|
||||||
R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 23},
|
R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 13, 17},
|
||||||
/*window_bounds=*/{2, 3, 7, 9},
|
/*window_bounds=*/{2, 3, 7, 9},
|
||||||
/*strides=*/{1, 2, 5, 8},
|
/*strides=*/{1, 2, 5, 8},
|
||||||
/*pad_low=*/{0, 0, 0, 0},
|
/*pad_low=*/{0, 0, 0, 0},
|
||||||
@ -900,7 +882,6 @@ INSTANTIATE_TEST_CASE_P(
|
|||||||
::testing::ValuesIn(use_bfloat16_params)),
|
::testing::ValuesIn(use_bfloat16_params)),
|
||||||
R4ReduceWindowTestDataToString);
|
R4ReduceWindowTestDataToString);
|
||||||
|
|
||||||
|
|
||||||
struct R3ReduceWindowTestData {
|
struct R3ReduceWindowTestData {
|
||||||
int64 base_bounds[3];
|
int64 base_bounds[3];
|
||||||
int64 window_bounds[3];
|
int64 window_bounds[3];
|
||||||
|
Loading…
Reference in New Issue
Block a user