From ffda6079ed619df8fd3edb4db71ffc7d005c2430 Mon Sep 17 00:00:00 2001 From: Tayo Oguntebi Date: Thu, 25 Jan 2018 23:37:20 -0800 Subject: [PATCH] Adds R1 test for ReduceWindow. PiperOrigin-RevId: 183345779 --- .../compiler/xla/tests/reduce_window_test.cc | 80 ++++++++++++------- 1 file changed, 51 insertions(+), 29 deletions(-) diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 73b37e201af..7f3c72671d5 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -1016,37 +1016,39 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, ::testing::tuple> { protected: R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } + + void DoIt() { + ComputationBuilder b(client_, TestName()); + const auto& param = ::testing::get<0>(GetParam()); + CHECK(param.reducer == kAdd); + + const float kInitValue = 0.0f; + Array2D input(param.base_bounds[0], param.base_bounds[1], 1.0f); + std::unique_ptr input_literal = + Literal::CreateR2FromArray2DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); + + ComputationDataHandle parameter; + auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + &b, ¶meter); + auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + b.ReduceWindow(/*operand=*/parameter, + /*init_value=*/init_value, + /*computation=*/CreateScalarAddComputation(FloatType(), &b), + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/param.padding); + + auto expected = ReferenceUtil::ReduceWindow2DAdd( + /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, + /*stride=*/param.strides, /*padding=*/param.padding); + + ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), + {input_arg.get()}, DefaultErrorSpec()); + } }; -TEST_P(R2ReduceWindowTest, Add) { - ComputationBuilder b(client_, TestName()); - const auto& param = ::testing::get<0>(GetParam()); - CHECK(param.reducer == kAdd); - - const float kInitValue = 0.0f; - Array2D input(param.base_bounds[0], param.base_bounds[1], 1.0f); - std::unique_ptr input_literal = - Literal::CreateR2FromArray2DWithLayout( - input, LayoutUtil::MakeLayout(param.layout)); - - ComputationDataHandle parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", - &b, ¶meter); - auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); - b.ReduceWindow(/*operand=*/parameter, - /*init_value=*/init_value, - /*computation=*/CreateScalarAddComputation(FloatType(), &b), - /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/param.padding); - - auto expected = ReferenceUtil::ReduceWindow2DAdd( - /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, - /*stride=*/param.strides, /*padding=*/param.padding); - - ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), - {input_arg.get()}, DefaultErrorSpec()); -} +TEST_P(R2ReduceWindowTest, DoIt) { DoIt(); } INSTANTIATE_TEST_CASE_P( R2ReduceWindowTestInstantiation, R2ReduceWindowTest, @@ -1054,6 +1056,26 @@ INSTANTIATE_TEST_CASE_P( ::testing::ValuesIn(use_bfloat16_params)), R2ReduceWindowTestDataToString); +class R2ReduceWindowFailingCpuGpuBf16Test : public R2ReduceWindowTest {}; + +// TODO(b/72234705): Fix the test cases failed on CPU and GPU. +XLA_TEST_P(R2ReduceWindowFailingCpuGpuBf16Test, + DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt)))) { + DoIt(); +} + +const R2ReduceWindowTestData kR2FailingValuesCpuGpuBf16Test[] = { + {/*base_bounds=*/{8, 128}, /*window_bounds=*/{8, 128}, + /*strides=*/{1, 1}, /*layout=*/{1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, +}; + +INSTANTIATE_TEST_CASE_P( + R2ReduceWindowFailingInstantiation, R2ReduceWindowFailingCpuGpuBf16Test, + ::testing::Combine(::testing::ValuesIn(kR2FailingValuesCpuGpuBf16Test), + ::testing::ValuesIn(use_bfloat16_params)), + R2ReduceWindowTestDataToString); + struct R1ReduceWindowTestData { int64 base_bounds[1]; int64 window_bounds[1];