Adds R1 test for ReduceWindow.
PiperOrigin-RevId: 183345779
This commit is contained in:
parent
e4912296bc
commit
ffda6079ed
@ -1016,37 +1016,39 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
|
|||||||
::testing::tuple<R2ReduceWindowTestData, bool>> {
|
::testing::tuple<R2ReduceWindowTestData, bool>> {
|
||||||
protected:
|
protected:
|
||||||
R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
|
R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
|
||||||
|
|
||||||
|
void DoIt() {
|
||||||
|
ComputationBuilder b(client_, TestName());
|
||||||
|
const auto& param = ::testing::get<0>(GetParam());
|
||||||
|
CHECK(param.reducer == kAdd);
|
||||||
|
|
||||||
|
const float kInitValue = 0.0f;
|
||||||
|
Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
|
||||||
|
std::unique_ptr<Literal> input_literal =
|
||||||
|
Literal::CreateR2FromArray2DWithLayout(
|
||||||
|
input, LayoutUtil::MakeLayout(param.layout));
|
||||||
|
|
||||||
|
ComputationDataHandle parameter;
|
||||||
|
auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
|
||||||
|
&b, ¶meter);
|
||||||
|
auto init_value =
|
||||||
|
CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
|
||||||
|
b.ReduceWindow(/*operand=*/parameter,
|
||||||
|
/*init_value=*/init_value,
|
||||||
|
/*computation=*/CreateScalarAddComputation(FloatType(), &b),
|
||||||
|
/*window_dimensions=*/param.window_bounds,
|
||||||
|
/*window_strides=*/param.strides, /*padding=*/param.padding);
|
||||||
|
|
||||||
|
auto expected = ReferenceUtil::ReduceWindow2DAdd(
|
||||||
|
/*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds,
|
||||||
|
/*stride=*/param.strides, /*padding=*/param.padding);
|
||||||
|
|
||||||
|
ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected),
|
||||||
|
{input_arg.get()}, DefaultErrorSpec());
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_P(R2ReduceWindowTest, Add) {
|
TEST_P(R2ReduceWindowTest, DoIt) { DoIt(); }
|
||||||
ComputationBuilder b(client_, TestName());
|
|
||||||
const auto& param = ::testing::get<0>(GetParam());
|
|
||||||
CHECK(param.reducer == kAdd);
|
|
||||||
|
|
||||||
const float kInitValue = 0.0f;
|
|
||||||
Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
|
|
||||||
std::unique_ptr<Literal> input_literal =
|
|
||||||
Literal::CreateR2FromArray2DWithLayout(
|
|
||||||
input, LayoutUtil::MakeLayout(param.layout));
|
|
||||||
|
|
||||||
ComputationDataHandle parameter;
|
|
||||||
auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
|
|
||||||
&b, ¶meter);
|
|
||||||
auto init_value =
|
|
||||||
CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
|
|
||||||
b.ReduceWindow(/*operand=*/parameter,
|
|
||||||
/*init_value=*/init_value,
|
|
||||||
/*computation=*/CreateScalarAddComputation(FloatType(), &b),
|
|
||||||
/*window_dimensions=*/param.window_bounds,
|
|
||||||
/*window_strides=*/param.strides, /*padding=*/param.padding);
|
|
||||||
|
|
||||||
auto expected = ReferenceUtil::ReduceWindow2DAdd(
|
|
||||||
/*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds,
|
|
||||||
/*stride=*/param.strides, /*padding=*/param.padding);
|
|
||||||
|
|
||||||
ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected),
|
|
||||||
{input_arg.get()}, DefaultErrorSpec());
|
|
||||||
}
|
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P(
|
INSTANTIATE_TEST_CASE_P(
|
||||||
R2ReduceWindowTestInstantiation, R2ReduceWindowTest,
|
R2ReduceWindowTestInstantiation, R2ReduceWindowTest,
|
||||||
@ -1054,6 +1056,26 @@ INSTANTIATE_TEST_CASE_P(
|
|||||||
::testing::ValuesIn(use_bfloat16_params)),
|
::testing::ValuesIn(use_bfloat16_params)),
|
||||||
R2ReduceWindowTestDataToString);
|
R2ReduceWindowTestDataToString);
|
||||||
|
|
||||||
|
class R2ReduceWindowFailingCpuGpuBf16Test : public R2ReduceWindowTest {};
|
||||||
|
|
||||||
|
// TODO(b/72234705): Fix the test cases failed on CPU and GPU.
|
||||||
|
XLA_TEST_P(R2ReduceWindowFailingCpuGpuBf16Test,
|
||||||
|
DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt)))) {
|
||||||
|
DoIt();
|
||||||
|
}
|
||||||
|
|
||||||
|
const R2ReduceWindowTestData kR2FailingValuesCpuGpuBf16Test[] = {
|
||||||
|
{/*base_bounds=*/{8, 128}, /*window_bounds=*/{8, 128},
|
||||||
|
/*strides=*/{1, 1}, /*layout=*/{1, 0},
|
||||||
|
/*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
|
||||||
|
};
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(
|
||||||
|
R2ReduceWindowFailingInstantiation, R2ReduceWindowFailingCpuGpuBf16Test,
|
||||||
|
::testing::Combine(::testing::ValuesIn(kR2FailingValuesCpuGpuBf16Test),
|
||||||
|
::testing::ValuesIn(use_bfloat16_params)),
|
||||||
|
R2ReduceWindowTestDataToString);
|
||||||
|
|
||||||
struct R1ReduceWindowTestData {
|
struct R1ReduceWindowTestData {
|
||||||
int64 base_bounds[1];
|
int64 base_bounds[1];
|
||||||
int64 window_bounds[1];
|
int64 window_bounds[1];
|
||||||
|
Loading…
Reference in New Issue
Block a user