From 2e618439493cb3ff2dee1ff71c49129a493c0850 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 15 Dec 2020 19:55:55 -0800
Subject: [PATCH] [XLA] Add implementation support for variadic reduce window,
 including HLO, cost analysis, etc.

PiperOrigin-RevId: 347742450
Change-Id: I06b02b9407013322f3e72865fec487385a47abec
---
 .../xla/service/algebraic_simplifier.cc       |   4 +
 .../compiler/xla/service/dynamic_padder.cc    |   7 ++
 .../compiler/xla/service/hlo_cost_analysis.cc |   5 +-
 .../xla/service/hlo_cost_analysis_test.cc     |  35 ++
 .../compiler/xla/service/hlo_instructions.cc  |   8 +-
 tensorflow/compiler/xla/service/hlo_parser.cc |  19 ++-
 .../compiler/xla/service/hlo_parser_test.cc   |  23 ++
 .../compiler/xla/service/hlo_verifier.cc      |   1 +
 .../xla/service/sharding_propagation.cc       |  15 ++-
 .../xla/service/space_to_batch_converter.cc   |   5 +
 .../xla/service/spmd/spmd_partitioner.cc      |   4 +
 .../compiler/xla/tests/reduce_window_test.cc  | 109 ++++++++++++++++++
 12 files changed, 224 insertions(+), 11 deletions(-)

diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 9a725cd541d..10e19e79596 100755
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -4696,6 +4696,10 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
 
 Status AlgebraicSimplifierVisitor::HandleReduceWindow(
     HloInstruction* reduce_window) {
+  // TODO(b/73062247) Variadic reduce window is not yet supported in simplifier.
+  if (reduce_window->shape().IsTuple()) {
+    return Status::OK();
+  }
   if (ShapeUtil::IsZeroElementArray(reduce_window->operand(0)->shape())) {
     return ReplaceWithNewInstruction(
         reduce_window,
diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc
index 167033b8f5f..ab94695c1e2 100644
--- a/tensorflow/compiler/xla/service/dynamic_padder.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder.cc
@@ -93,6 +93,9 @@ StatusOr<HloInstruction*> ChooseIdentityValue(HloInstruction* inst,
       return inst->mutable_operand(init_value_index);
     }
     case HloOpcode::kReduceWindow: {
+      if (inst->shape().IsTuple()) {
+        return Unimplemented("Variadic reduce window not yet supported.");
+      }
       // Because of the way we do reduce, we already require the `init`
       // operand of hlo reduce instruction to be identity value. Here we reuse
       // the operand.
@@ -1015,6 +1018,10 @@ StatusOr<bool> RewriteDynamicConvolutionKernelGrad(
 StatusOr<bool> RewriteDynamicReduceWindowSamePadding(
     HloInstruction* hlo,
     DynamicDimensionInference* dynamic_dimension_inference) {
+  if (hlo->shape().IsTuple()) {
+    // TODO(b/73062247) Variadic reduce window is not yet supported here.
+    return Unimplemented("Variadic reduce window not yet supported.");
+  }
   HloInstruction* input = hlo->mutable_operand(0);
   HloInstruction* init = hlo->mutable_operand(1);
   HloComputation* comp = hlo->parent();
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 939c713fc18..4ed89c4bfc9 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -396,8 +396,11 @@ Status HloCostAnalysis::HandleReduceWindow(
   for (const auto& dimension : window.dimensions()) {
     window_element_count *= dimension.size();
   }
+
   const int64 output_element_count =
-      ShapeUtil::ElementsIn(reduce_window->shape());
+      ShapeUtil::ElementsIn(reduce_window->shape().IsArray()
+                                ? reduce_window->shape()
+                                : reduce_window->shape().tuple_shapes(0));
   const int64 reduction_count =
       (window_element_count - 1) * output_element_count;
   for (const auto& property : sub_properties) {
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index 8f2b9a67790..748eb40b80b 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -449,6 +449,41 @@ TEST_F(HloCostAnalysisTest, ReduceWindow) {
   EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 2 * 4);
 }
 
+TEST_F(HloCostAnalysisTest, ReduceWindowVariadic) {
+  XlaBuilder builder("reduce_window_variadic");
+  auto elem_shape = ShapeUtil::MakeShape(F32, {});
+  auto p2 = Parameter(&builder, 0, elem_shape, "x0");
+  auto p3 = Parameter(&builder, 1, elem_shape, "x1");
+  auto p4 = Parameter(&builder, 2, elem_shape, "y0");
+  auto p5 = Parameter(&builder, 3, elem_shape, "y1");
+  absl::InlinedVector<XlaOp, 2> compute_vec = {Min(p2, p4), Min(p3, p5)};
+  Tuple(&builder, compute_vec);
+  TF_ASSERT_OK_AND_ASSIGN(auto compute_tuple, builder.Build());
+  auto input1 =
+      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input1");
+  auto input2 =
+      Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {10, 20}), "input2");
+  auto init = ConstantR0<float>(&builder, 0);
+  ReduceWindow({input1, input2}, {init, init}, compute_tuple, {4, 5}, {4, 5},
+               Padding::kValid);
+
+  // Run HLO cost analysis.
+  auto hlo_module = BuildHloGraph(&builder);
+  HloCostAnalysis analysis(ShapeSize);
+  ASSERT_IS_OK(
+      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+  // Each of the [2x4] output elements is generated from reducing [4x5] elements.
+  EXPECT_EQ(analysis.flop_count(), 2 * 4 * 2 * (4 * 5 - 1));
+
+  EXPECT_EQ(analysis.bytes_accessed(), sizeof(float) * (10 * 20 * 2 + 2 * 3));
+
+  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
+  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(float) * 10 * 20);
+  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float) * 10 * 20);
+  EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 4);
+}
+
 TEST_F(HloCostAnalysisTest, SelectAndScatter) {
   XlaBuilder builder("select_and_scatter");
   auto operand =
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index e43f68fd257..e203e63dee9 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -2285,9 +2285,13 @@ std::unique_ptr<HloInstruction>
 HloReduceWindowInstruction::CloneWithNewOperandsImpl(
     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
     HloCloneContext* context) const {
-  CHECK_EQ(new_operands.size(), 2);
+  CHECK_EQ(new_operands.size() % 2, 0);
+  int64 num_operands = new_operands.size() / 2;
   return absl::make_unique<HloReduceWindowInstruction>(
-      shape, new_operands[0], new_operands[1], window(), to_apply());
+      shape, absl::MakeSpan(new_operands).subspan(0, num_operands),
+      absl::MakeSpan(new_operands)
+          .subspan(num_operands, new_operands.size() / 2),
+      window(), to_apply());
 }
 
 HloSelectAndScatterInstruction::HloSelectAndScatterInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 558a5029960..675b60b5453 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -1366,16 +1366,25 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
       attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                            &reduce_computation};
-      if (!ParseOperands(&operands, /*expected_size=*/2) ||
-          !ParseAttributes(attrs)) {
+      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
         return false;
       }
       if (!window) {
        window.emplace();
       }
+      if (operands.size() % 2) {
+        auto loc = lexer_.GetLoc();
+        return Error(loc, StrCat("expects an even number of operands, but has ",
+                                 operands.size(), " operands"));
+      }
       instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow(
-          shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window,
-          *reduce_computation));
+          shape, /*operands=*/
+          absl::Span<HloInstruction* const>(operands).subspan(
+              0, operands.size() / 2),
+          /*init_values=*/
+          absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
+                                                              2),
+          *window, *reduce_computation));
       break;
     }
     case HloOpcode::kConvolution: {
@@ -3585,7 +3594,7 @@ bool HloParserImpl::ParseWindow(Window* window, bool expect_outer_curlies) {
 }
 
 // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
-// Thestring looks like "dim_labels=0bf_0io->0bf".
+// The string looks like "dim_labels=0bf_0io->0bf".
 bool HloParserImpl::ParseConvolutionDimensionNumbers(
     ConvolutionDimensionNumbers* dnums) {
   if (lexer_.GetKind() != TokKind::kDimLabels) {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index dc94e30c847..27b0de538b6 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -437,6 +437,29 @@ ENTRY %R4UnitWindowScalar () -> f32[] {
   ROOT %reduce-window = f32[] reduce-window(f32[] %constant, f32[] %constant.1), to_apply=%add_F32.v3
 }
 
+)"
+},
+// variadic reduce window
+{
+"ReduceWindowVariadic",
+R"(HloModule reduce_window_variadic
+
+%add_F32.v3 (lhs1: f32[], lhs2: f32[], rhs1: f32[], rhs2: f32[]) -> (f32[], f32[]) {
+  %lhs1 = f32[] parameter(0)
+  %rhs1 = f32[] parameter(2)
+  %add1 = f32[] add(f32[] %lhs1, f32[] %rhs1)
+  %lhs2 = f32[] parameter(1)
+  %rhs2 = f32[] parameter(3)
+  %add2 = f32[] add(f32[] %lhs2, f32[] %rhs2)
+  ROOT %tuple1 = (f32[], f32[]) tuple(f32[] %add1, f32[] %add2)
+}
+
+ENTRY %R4UnitWindowScalar () -> (f32[], f32[]) {
+  %constant = f32[] constant(42)
+  %constant.1 = f32[] constant(1)
+  ROOT %reduce-window = (f32[], f32[]) reduce-window(f32[] %constant, f32[] %constant, f32[] %constant.1, f32[] %constant.1), to_apply=%add_F32.v3
+}
+
 )"
 },
 // convolution
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 84e4fe6e3fd..80d2cd3e7b7 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -1086,6 +1086,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
     case HloOpcode::kRecv:
     case HloOpcode::kRecvDone:
     case HloOpcode::kReducePrecision:
+    case HloOpcode::kReduceWindow:
     case HloOpcode::kTupleSelect:
     case HloOpcode::kSend:
     case HloOpcode::kSendDone:
diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc
index 6e5d77d067d..af5471e63a2 100644
--- a/tensorflow/compiler/xla/service/sharding_propagation.cc
+++ b/tensorflow/compiler/xla/service/sharding_propagation.cc
@@ -666,12 +666,13 @@ bool InferShardingFromOperands(HloInstruction* instruction,
     return false;
   }
   // Propagate manual sharding. Avoid tuple shaped HLOs that group independent
-  // together. Reduce and Sort can be tuples but the elements are correlated, so
-  // we propagate manual sharding through them.
+  // values together. Reduce, ReduceWindow, and Sort can be tuples but the
+  // elements are correlated, so we propagate manual sharding through them.
   if (!instruction->has_sharding() &&
       (instruction->shape().IsArray() ||
        instruction->opcode() == HloOpcode::kReduce ||
-       instruction->opcode() == HloOpcode::kSort) &&
+       instruction->opcode() == HloOpcode::kSort ||
+       instruction->opcode() == HloOpcode::kReduceWindow) &&
       absl::c_any_of(instruction->operands(), [](const HloInstruction* op) {
         return op->has_sharding() && op->sharding().IsManual();
       })) {
@@ -868,6 +869,10 @@ bool InferShardingFromOperands(HloInstruction* instruction,
           may_combine_partial_sharding);
     }
     case HloOpcode::kReduceWindow: {
+      if (instruction->shape().IsTuple()) {
+        // TODO(b/73062247) Variadic reduce window is not yet supported here.
+        return false;
+      }
      const HloInstruction* lhs = instruction->operand(0);
       if (!IsSpatiallyPartitioned(lhs)) {
         return false;
       }
@@ -1292,6 +1297,10 @@ absl::optional<HloSharding> GetShardingFromUser(
       return user.sharding();
     }
     case HloOpcode::kReduceWindow: {
+      if (user.shape().IsTuple()) {
+        return user.sharding().GetSubSharding(
+            user.shape(), {user.operand_index(&instruction)});
+      }
       if (&instruction != user.operand(0)) {
         return absl::nullopt;
       }
diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter.cc b/tensorflow/compiler/xla/service/space_to_batch_converter.cc
index badd1bbeae4..8f7cc1af74a 100644
--- a/tensorflow/compiler/xla/service/space_to_batch_converter.cc
+++ b/tensorflow/compiler/xla/service/space_to_batch_converter.cc
@@ -922,6 +922,11 @@ bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer,
            absl::c_linear_search(reduce_dims, space_dim);
   }
 
+  if (consumer->opcode() == HloOpcode::kReduceWindow &&
+      consumer->shape().IsTuple()) {
+    // TODO(b/73062247) Variadic reduce window is not yet supported.
+    return false;
+  }
   if (consumer->opcode() == HloOpcode::kReduceWindow ||
       consumer->opcode() == HloOpcode::kSelectAndScatter) {
     auto first_operand = consumer->mutable_operand(0);
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
index d541cee1c01..4064308aa40 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
@@ -3415,6 +3415,10 @@ Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) {
 }
 
 Status SpmdPartitioningVisitor::HandleReduceWindow(HloInstruction* hlo) {
+  // TODO(b/73062247) Variadic reduce window not yet supported in partitioner.
+  if (hlo->shape().IsTuple()) {
+    return DefaultAction(hlo);
+  }
   auto& operand = GetPartitionedHlo(hlo->operand(0));
   if (hlo->sharding().IsTileMaximal()) {
     return DefaultAction(hlo);
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 7e5b699d5e2..d86ebfa83cb 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -1704,5 +1704,114 @@ ENTRY R4OnlyDilation {
   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
 }
 
+XLA_TEST_F(HloTestBase,
+           DISABLED_ON_GPU(DISABLED_ON_CPU(ReduceWindowVariadicSupport))) {
+  const char* const hlo_string = R"(
+HloModule module
+
+sum {
+  a0 = f32[] parameter(0)
+  a1 = f32[] parameter(1)
+  b0 = f32[] parameter(2)
+  b1 = f32[] parameter(3)
+  add0 = f32[] add(a0, b0)
+  add1 = f32[] add(a1, b1)
+  ROOT sum2 = (f32[], f32[]) tuple(add0, add1)
+}
+
+ENTRY entry {
+  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.1 = f32[] constant(0)
+  constant.2 = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.3 = f32[] constant(0)
+  reduce-window = (f32[2,2]{1,0}, f32[2,2]{1,0})
+    reduce-window(constant, constant.2, constant.1, constant.3),
+    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
+  ROOT copy = (f32[2,2]{1,0}, f32[2,2]{1,0}) copy(reduce-window)
+})";
+  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-4, 1e-4}));
+}
+
+XLA_TEST_F(HloTestBase,
+           DISABLED_ON_GPU(DISABLED_ON_CPU(ReduceWindowVariadicSupport2))) {
+  const char* const hlo_string = R"(
+HloModule module
+
+sum {
+  a0 = f32[] parameter(0)
+  a1 = s32[] parameter(1)
+  b0 = f32[] parameter(2)
+  b1 = s32[] parameter(3)
+  add0 = f32[] add(a0, b0)
+  add1 = s32[] add(a1, b1)
+  ROOT sum2 = (f32[], s32[]) tuple(add0, add1)
+}
+
+ENTRY entry {
+  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.1 = f32[] constant(0)
+  constant.2 = s32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.3 = s32[] constant(0)
+  ROOT reduce-window = (f32[2,2]{1,0}, s32[2,2]{1,0})
+    reduce-window(constant, constant.2, constant.1, constant.3),
+    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
+})";
+  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-4, 1e-4}));
+}
+
+XLA_TEST_F(HloTestBase,
+           DISABLED_ON_GPU(DISABLED_ON_CPU(ReduceWindowVariadicSupport3))) {
+  const char* const hlo_string = R"(
+HloModule module
+
+sum {
+  a0 = f32[] parameter(0)
+  a1 = bf16[] parameter(1)
+  b0 = f32[] parameter(2)
+  b1 = bf16[] parameter(3)
+  add0 = f32[] add(a0, b0)
+  add1 = bf16[] add(a1, b1)
+  ROOT sum2 = (f32[], bf16[]) tuple(add0, add1)
+}
+
+ENTRY entry {
+  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.1 = f32[] constant(0)
+  constant.2 = bf16[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.3 = bf16[] constant(0)
+  ROOT reduce-window = (f32[2,2]{1,0}, bf16[2,2]{1,0})
+    reduce-window(constant, constant.2, constant.1, constant.3),
+    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
+})";
+  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-4, 1e-4}));
+}
+
+XLA_TEST_F(HloTestBase,
+           DISABLED_ON_GPU(DISABLED_ON_CPU(ReduceWindowVariadicSupport4))) {
+  const char* const hlo_string = R"(
+HloModule module
+
+sum {
+  a0 = f32[] parameter(0)
+  a1 = bf16[] parameter(1)
+  b0 = f32[] parameter(2)
+  b1 = bf16[] parameter(3)
+  add0 = f32[] add(a0, b0)
+  add1 = bf16[] multiply(a1, b1)
+  ROOT sum2 = (f32[], bf16[]) tuple(add0, add1)
+}
+
+ENTRY entry {
+  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.1 = f32[] constant(0)
+  constant.2 = bf16[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
+  constant.3 = bf16[] constant(1)
+  ROOT reduce-window = (f32[2,2]{1,0}, bf16[2,2]{1,0})
+    reduce-window(constant, constant.2, constant.1, constant.3),
+    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
+})";
+  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-4, 1e-4}));
+}
+
 }  // namespace
 }  // namespace xla
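
Usage sketch (illustrative note, not part of the patch): the variadic form is driven through the client-side ReduceWindow overload that the ReduceWindowVariadic cost-analysis test above exercises: N operands, N init values, and a to_apply computation that takes 2N scalar parameters and returns an N-tuple. Per the parser test's %add_F32.v3, parameters 0..N-1 are the accumulator values and parameters N..2N-1 the incoming values, paired index-wise. Below is a minimal C++ sketch of that calling pattern, assuming only the headers and free functions already used by the tests in this patch; the helper name and the sub-builder setup are hypothetical.

// Hypothetical helper (not part of this change): a two-operand reduce-window
// that sums both inputs in lockstep over 4x5 windows with 4x5 strides.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

XlaOp VariadicSumReduceWindow(XlaBuilder* builder, XlaOp input0, XlaOp input1) {
  // Combiner: (a0, a1, b0, b1) -> (a0 + b0, a1 + b1). The first N parameters
  // are the accumulators, the last N the incoming window elements.
  XlaBuilder sub("pairwise_sum");
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  XlaOp a0 = Parameter(&sub, 0, scalar, "a0");
  XlaOp a1 = Parameter(&sub, 1, scalar, "a1");
  XlaOp b0 = Parameter(&sub, 2, scalar, "b0");
  XlaOp b1 = Parameter(&sub, 3, scalar, "b1");
  Tuple(&sub, {Add(a0, b0), Add(a1, b1)});
  XlaComputation pairwise_sum = sub.Build().ValueOrDie();

  XlaOp init = ConstantR0<float>(builder, 0);
  // The result is tuple-shaped: one array per operand, identical geometry.
  return ReduceWindow({input0, input1}, {init, init}, pairwise_sum,
                      /*window_dimensions=*/{4, 5},
                      /*window_strides=*/{4, 5}, Padding::kValid);
}

}  // namespace xla

As the hlo_parser.cc change above shows, the flat HLO operand list is split in half (the first N operands are the data inputs, the last N the init values), which is why the parser now rejects an odd operand count.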