From 96bd0a0b6895d73fa2b6ab2bad374b22d8607f03 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 22 Oct 2020 00:13:09 -0700
Subject: [PATCH] [XLA] Add xla builder support for variadic reduce window op.

This is the first CL leading to full support of variadic reduce window.

PiperOrigin-RevId: 338417352
Change-Id: I8e5907f0ddf2a29081c4d84d593b30f5c3eda6ed
---
 tensorflow/compiler/xla/client/xla_builder.cc |  81 ++++++++++++++++---
 tensorflow/compiler/xla/client/xla_builder.h  |  28 ++++++-
 .../compiler/xla/client/xla_builder_test.cc   |  42 ++++++++++
 tensorflow/compiler/xla/service/BUILD         |   3 +
 .../xla/service/elemental_ir_emitter.cc       |   4 +
 .../xla/service/hlo_evaluator_typed_visitor.h |   4 +
 .../compiler/xla/service/hlo_instruction.cc   |  23 +++++-
 .../compiler/xla/service/hlo_instruction.h    |  10 +++
 .../compiler/xla/service/hlo_instructions.cc  |  16 +++-
 .../compiler/xla/service/hlo_instructions.h   |  35 ++++++++
 tensorflow/compiler/xla/service/hlo_opcode.h  |   2 +-
 .../compiler/xla/service/hlo_opcode_test.cc   |   1 +
 .../compiler/xla/service/shape_inference.cc   |  39 ++++++++-
 .../compiler/xla/service/shape_inference.h    |   8 +-
 .../xla/service/shape_inference_test.cc       |  50 ++++++++++++
 15 files changed, 322 insertions(+), 24 deletions(-)

diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 41212e69b2e..b44673015bb 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -2324,31 +2324,53 @@ XlaOp XlaBuilder::ReduceWindow(XlaOp operand, XlaOp init_value,
                                absl::Span<const int64> window_dimensions,
                                absl::Span<const int64> window_strides,
                                Padding padding) {
-  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
-    TF_RETURN_IF_ERROR(
-        ValidatePaddingValues(AsInt64Slice(operand_shape->dimensions()),
-                              window_dimensions, window_strides));
+  return ReduceWindow(absl::MakeSpan(&operand, 1),
+                      absl::MakeSpan(&init_value, 1), computation,
+                      window_dimensions, window_strides, padding);
+}
+
+XlaOp XlaBuilder::ReduceWindow(absl::Span<const XlaOp> operands,
+                               absl::Span<const XlaOp> init_values,
+                               const XlaComputation& computation,
+                               absl::Span<const int64> window_dimensions,
+                               absl::Span<const int64> window_strides,
+                               Padding padding) {
+  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    const Shape* operand_shape = nullptr;
+    for (const auto& operand : operands) {
+      TF_ASSIGN_OR_RETURN(operand_shape, GetShapePtr(operand));
+      TF_RETURN_IF_ERROR(
+          ValidatePaddingValues(AsInt64Slice(operand_shape->dimensions()),
+                                window_dimensions, window_strides));
+    }
+    CHECK(operand_shape != nullptr);
     std::vector<std::pair<int64, int64>> padding_values =
         MakePadding(AsInt64Slice(operand_shape->dimensions()),
                     window_dimensions, window_strides, padding);
     return ReduceWindowWithGeneralPadding(
-        operand, init_value, computation, window_dimensions, window_strides,
+        operands, init_values, computation, window_dimensions, window_strides,
        /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values);
   });
 }
 
 XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
-    XlaOp operand, XlaOp init_value, const XlaComputation& computation,
+    absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
+    const XlaComputation& computation,
     absl::Span<const int64> window_dimensions,
     absl::Span<const int64> window_strides,
     absl::Span<const int64> base_dilations,
     absl::Span<const int64> window_dilations,
     absl::Span<const std::pair<int64, int64>> padding) {
+  std::vector<const Shape*> operand_shapes, init_shapes;
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
-    TF_ASSIGN_OR_RETURN(const Shape* init_shape,
-                        GetShapePtr(init_value));
+    for (int i = 0; i < operands.size(); ++i) {
+      const auto& operand = operands[i];
+      const auto& init_value = init_values[i];
+      TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
+      operand_shapes.push_back(operand_shape);
+      TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
+      init_shapes.push_back(init_shape);
+    }
     TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
                         computation.GetProgramShape());
     TF_ASSIGN_OR_RETURN(auto window,
@@ -2358,12 +2380,33 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
                                   /*rhs_dilation=*/window_dilations));
     TF_ASSIGN_OR_RETURN(
         Shape shape, ShapeInference::InferReduceWindowShape(
-                         *operand_shape, *init_shape, window, to_apply_shape));
-    return ReduceWindowInternal(shape, operand, init_value, computation,
+                         absl::MakeSpan(operand_shapes),
+                         absl::MakeSpan(init_shapes), window, to_apply_shape));
+    return ReduceWindowInternal(shape, operands, init_values, computation,
                                 std::move(window));
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::ReduceWindowInternal(
+    const Shape& shape, absl::Span<const XlaOp> operands,
+    absl::Span<const XlaOp> init_values, const XlaComputation& computation,
+    Window window) {
+  if (operands.size() == 1) {
+    return ReduceWindowInternal(shape, operands[0], init_values[0], computation,
+                                window);
+  } else {
+    HloInstructionProto instr;
+    *instr.mutable_shape() = shape.ToProto();
+    *instr.mutable_window() = std::move(window);
+    AddCalledComputation(computation, &instr);
+    std::vector<XlaOp> args;
+    args.insert(args.end(), operands.begin(), operands.end());
+    args.insert(args.end(), init_values.begin(), init_values.end());
+    return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
+                          absl::MakeSpan(args));
+  }
+}
+
 StatusOr<XlaOp> XlaBuilder::ReduceWindowInternal(
     const Shape& shape, XlaOp operand, XlaOp init_value,
     const XlaComputation& computation, Window window) {
@@ -4067,6 +4110,17 @@ XlaOp ReduceWindow(const XlaOp operand, const XlaOp init_value,
                    padding);
 }
 
+XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
+                   absl::Span<const XlaOp> init_values,
+                   const XlaComputation& computation,
+                   absl::Span<const int64> window_dimensions,
+                   absl::Span<const int64> window_strides, Padding padding) {
+  CHECK(!operands.empty());
+  return operands[0].builder()->ReduceWindow(operands, init_values, computation,
+                                             window_dimensions, window_strides,
+                                             padding);
+}
+
 XlaOp ReduceWindowWithGeneralPadding(
     const XlaOp operand, const XlaOp init_value,
     const XlaComputation& computation,
@@ -4076,8 +4130,9 @@ XlaOp ReduceWindowWithGeneralPadding(
     absl::Span<const int64> window_dilations,
     absl::Span<const std::pair<int64, int64>> padding) {
   return operand.builder()->ReduceWindowWithGeneralPadding(
-      operand, init_value, computation, window_dimensions, window_strides,
-      base_dilations, window_dilations, padding);
+      absl::MakeSpan(&operand, 1), absl::MakeSpan(&init_value, 1), computation,
+      window_dimensions, window_strides, base_dilations, window_dilations,
+      padding);
 }
 
 XlaOp AllGather(const XlaOp operand, int64 all_gather_dimension,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index f736ae1d470..05efc038082 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -648,18 +648,28 @@ class XlaBuilder {
                      absl::Span<const int64> window_dimensions,
                      absl::Span<const int64> window_strides, Padding padding);
 
+  XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
+                     absl::Span<const XlaOp> init_values,
+                     const XlaComputation& computation,
+                     absl::Span<const int64> window_dimensions,
+                     absl::Span<const int64> window_strides, Padding padding);
+
   XlaOp ReduceWindowWithGeneralPadding(
-      XlaOp operand, XlaOp init_value, const XlaComputation& computation,
+      absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
+      const XlaComputation& computation,
       absl::Span<const int64> window_dimensions,
       absl::Span<const int64> window_strides,
       absl::Span<const int64> base_dilations,
       absl::Span<const int64> window_dilations,
       absl::Span<const std::pair<int64, int64>> padding);
-
+  StatusOr<XlaOp> ReduceWindowInternal(const Shape& shape,
+                                       absl::Span<const XlaOp> operands,
+                                       absl::Span<const XlaOp> init_values,
+                                       const XlaComputation& computation,
+                                       Window window);
   virtual StatusOr<XlaOp> ReduceWindowInternal(
       const Shape& shape, XlaOp operand, XlaOp init_value,
       const XlaComputation& computation, Window window);
-
   XlaOp CrossReplicaSum(XlaOp operand,
                         absl::Span<const ReplicaGroup> replica_groups = {});
@@ -1137,6 +1147,12 @@
                                 absl::Span<const int64> window_dimensions,
                                 absl::Span<const int64> window_strides,
                                 Padding padding);
+  friend XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
+                            absl::Span<const XlaOp> init_values,
+                            const XlaComputation& computation,
+                            absl::Span<const int64> window_dimensions,
+                            absl::Span<const int64> window_strides,
+                            Padding padding);
   friend XlaOp ReduceWindowWithGeneralPadding(
       XlaOp operand, XlaOp init_value, const XlaComputation& computation,
       absl::Span<const int64> window_dimensions,
@@ -1965,6 +1981,12 @@ XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
                    absl::Span<const int64> window_dimensions,
                    absl::Span<const int64> window_strides, Padding padding);
 
+XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
+                   absl::Span<const XlaOp> init_values,
+                   const XlaComputation& computation,
+                   absl::Span<const int64> window_dimensions,
+                   absl::Span<const int64> window_strides, Padding padding);
+
 // As ReduceWindow(), but the padding is given in the format
 // returned by MakePadding().
 XlaOp ReduceWindowWithGeneralPadding(
diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc
index bfd13c8ddf5..4fc6c848a38 100644
--- a/tensorflow/compiler/xla/client/xla_builder_test.cc
+++ b/tensorflow/compiler/xla/client/xla_builder_test.cc
@@ -873,6 +873,8 @@ TEST_F(XlaBuilderTest, DynamicReduceWindow) {
   ReduceWindow(gte, init, sum, /*window_dimensions=*/{1, 2, 4},
                /*window_strides=*/{1, 1, 1}, Padding::kValid);
   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  VLOG(2) << module->entry_computation()->root_instruction()->ToString()
+          << "\n";
   const Shape& result_shape =
       module->entry_computation()->root_instruction()->shape();
   EXPECT_TRUE(
@@ -880,6 +882,46 @@
       << result_shape;
 }
 
+TEST_F(XlaBuilderTest, VariadicDynamicReduceWindow) {
+  XlaBuilder b(TestName());
+  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
+      {ShapeUtil::MakeShape(F32, {2, 4, 8}, {true, false, false}),
+       ShapeUtil::MakeShape(U32, {})});
+  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
+  auto p1 = Parameter(&b, 1, tuple_param_shape, "p1");
+  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
+                                   /*dynamic_size_param_index=*/{1},
+                                   /*target_param_num=*/0,
+                                   /*target_param_index=*/{0},
+                                   /*target_dim_num=*/0));
+  auto gte0 = GetTupleElement(p0, 0);
+  auto gte1 = GetTupleElement(p1, 0);
+  std::vector<XlaOp> input_operands = {gte0, gte1};
+  XlaBuilder bsum(TestName());
+  auto p2 = Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x0");
+  auto p3 = Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "x1");
+  auto p4 = Parameter(&bsum, 2, ShapeUtil::MakeShape(F32, {}), "y0");
+  auto p5 = Parameter(&bsum, 3, ShapeUtil::MakeShape(F32, {}), "y1");
+  std::vector<XlaOp> output_operands = {Add(p2, p4), Add(p3, p5)};
+  Tuple(&bsum, absl::MakeSpan(output_operands));
+  TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build());
+  auto init = ConstantR0<float>(&b, 0.f);
+  ReduceWindow(input_operands, {init, init}, sum,
+               /*window_dimensions=*/{1, 2, 4},
+               /*window_strides=*/{1, 1, 1}, Padding::kValid);
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  VLOG(2) << module->entry_computation()->root_instruction()->ToString()
+          << "\n";
+  const Shape& result_shape =
+      module->entry_computation()->root_instruction()->shape();
+  EXPECT_TRUE(ContainersEqual(result_shape.tuple_shapes(0).dynamic_dimensions(),
+                              {true, false, false}))
+      << result_shape.tuple_shapes(0);
+  EXPECT_TRUE(ContainersEqual(result_shape.tuple_shapes(1).dynamic_dimensions(),
+                              {true, false, false}))
+      << result_shape.tuple_shapes(1);
+}
+
 TEST_F(XlaBuilderTest, DynamicSelectAndScatter) {
   XlaBuilder b(TestName());
   Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 491d1d67877..e16575bebd4 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -241,7 +241,10 @@ tf_cc_test(
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla:test_helpers",
         "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:window_util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/client:padding",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # fixdeps: keep
         "//tensorflow/core:lib",
         "@com_google_absl//absl/strings",
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 98d523487b4..3a449b7c2db 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -2540,6 +2540,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduceWindow(
   //     if I in bounds of input
   //       value = function(value, input[I])
   //   output[O] = value
+  if (reduce_window->shape().IsTuple()) {
+    return Status(tensorflow::error::UNIMPLEMENTED,
+                  "Variadic reduce window op is not yet fully supported.");
+  }
   const HloInstruction* operand = reduce_window->operand(0);
   const Window& window = reduce_window->window();
 
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 4fb7edd0104..4ddd8ce5146 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -1932,6 +1932,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
   }
 
   Status HandleReduceWindow(HloInstruction* reduce_window) override {
+    if (reduce_window->shape().IsTuple()) {
+      return Status(tensorflow::error::UNIMPLEMENTED,
+                    "Variadic reduce window op is not yet fully supported.");
+    }
     auto operand = reduce_window->operand(0);
     const Window& window = reduce_window->window();
     HloComputation* function = reduce_window->to_apply();
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 41488dcdaaa..b24c35c4c69 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -515,11 +515,23 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
       break;
     }
     case HloOpcode::kReduceWindow:
+      TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
+          << "Reduce window should have an even number of operands but "
+             "sees "
+          << proto.operand_ids_size();
       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "ReduceWindow should have 1 called computation but sees "
          << proto.called_computation_ids_size();
-      instruction = CreateReduceWindow(shape, operands(0), operands(1),
-                                       proto.window(), computations(0));
+      {
+        const auto reduce_operands = all_operands();
+        auto inputs = absl::MakeSpan(reduce_operands)
+                          .subspan(0, reduce_operands.size() / 2);
+        auto init_values =
+            absl::MakeSpan(reduce_operands)
+                .subspan(reduce_operands.size() / 2, reduce_operands.size());
+        instruction = CreateReduceWindow(shape, inputs, init_values,
+                                         proto.window(), computations(0));
+      }
       break;
     case HloOpcode::kSelectAndScatter:
       TF_RET_CHECK(proto.called_computation_ids_size() == 2)
@@ -1273,6 +1285,13 @@ HloInstruction::CreateBitcastConvert(const Shape& shape,
       shape, operand, init_value, window, reduce_computation);
 }
 
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
+    const Shape& shape, absl::Span<HloInstruction* const> operands,
+    absl::Span<HloInstruction* const> init_values, const Window& window,
+    HloComputation* reduce_computation) {
+  return absl::make_unique<HloReduceWindowInstruction>(
+      shape, operands, init_values, window, reduce_computation);
+}
 /* static */ std::unique_ptr<HloInstruction>
 HloInstruction::CreateBatchNormTraining(const Shape& shape,
                                         HloInstruction* operand,
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 9675a2f0f0d..5901e446df5 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -830,6 +830,16 @@ class HloInstruction {
       const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
       const Window& window, HloComputation* reduce_computation);
 
+  // A more general, multiple-argument version of the above.
+  // The reduce_computation being applied now takes N arguments:
+  // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ...,
+  // valueN], and returns an N-tuple. The operands and init_values now each
+  // contain a span of N input arrays and N initial values.
+  static std::unique_ptr<HloInstruction> CreateReduceWindow(
+      const Shape& shape, absl::Span<HloInstruction* const> operands,
+      absl::Span<HloInstruction* const> init_values, const Window& window,
+      HloComputation* reduce_computation);
+
   // Creates a batch-norm-training instruction.
   static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 45b2d885d8e..8cb7d91f5ac 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -2237,9 +2237,21 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl(
 HloReduceWindowInstruction::HloReduceWindowInstruction(
     const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
     const Window& window, HloComputation* reduce_computation)
+    : HloReduceWindowInstruction(shape, absl::MakeSpan(&operand, 1),
+                                 absl::MakeSpan(&init_value, 1), window,
+                                 reduce_computation) {}
+
+HloReduceWindowInstruction::HloReduceWindowInstruction(
+    const Shape& shape, absl::Span<HloInstruction* const> operands,
+    absl::Span<HloInstruction* const> init_values, const Window& window,
+    HloComputation* reduce_computation)
     : HloInstruction(HloOpcode::kReduceWindow, shape), window_(window) {
-  AppendOperand(operand);
-  AppendOperand(init_value);
+  for (auto* operand : operands) {
+    AppendOperand(operand);
+  }
+  for (auto* init_value : init_values) {
+    AppendOperand(init_value);
+  }
   AppendComputation(reduce_computation);
 }
 
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 88e874347bd..848674fc604 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -20,6 +20,7 @@ limitations under the License.
 
 #include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
 namespace xla {
@@ -1294,10 +1295,43 @@ class HloReduceWindowInstruction : public HloInstruction {
                              HloInstruction* init_value, const Window& window,
                              HloComputation* reduce_computation);
+  explicit HloReduceWindowInstruction(
+      const Shape& shape, absl::Span<HloInstruction* const> operands,
+      absl::Span<HloInstruction* const> init_values, const Window& window,
+      HloComputation* reduce_computation);
   const Window& window() const override { return window_; }
   void set_window(const Window& window) override { window_ = window; }
   // Returns a serialized representation of this instruction.
   HloInstructionProto ToProto() const override;
+  // Returns the number of input arrays (and, consequently, the number of
+  // init values) this reduce has.
+  int64 input_count() const { return operand_count() / 2; }
+  // Returns the input tensors to be reduced.
+  absl::Span<HloInstruction* const> input_arrays() const {
+    return absl::MakeSpan(operands()).subspan(0, input_count());
+  }
+  // Returns the init values of the reduction.
+  absl::Span<HloInstruction* const> init_values() const {
+    return absl::MakeSpan(operands()).subspan(input_count(), operand_count());
+  }
+  // Returns the shapes of input tensors to be reduced.
+  absl::InlinedVector<const Shape*, 2> input_array_shapes() const {
+    absl::InlinedVector<const Shape*, 2> shapes;
+    for (const auto* op : input_arrays()) {
+      VLOG(2) << "Pushing input array shape for: " << op->ToString() << "\n";
+      shapes.push_back(&op->shape());
+      VLOG(2) << "Pushed shape: " << shapes.back()->ToString() << "\n";
+    }
+    return shapes;
+  }
+  // Returns the shapes of the init values of the reduction.
+  absl::InlinedVector<const Shape*, 2> init_value_shapes() const {
+    absl::InlinedVector<const Shape*, 2> shapes;
+    for (const auto* op : init_values()) {
+      shapes.push_back(&op->shape());
+    }
+    return shapes;
+  }
 
 private:
  std::vector<string> ExtraAttributesToStringImpl(
@@ -1310,6 +1344,7 @@ class HloReduceWindowInstruction : public HloInstruction {
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
+
  Window window_;
 };
 
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index b50c7d9a584..e14d86e6bc0 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -119,7 +119,7 @@ namespace xla {
   V(kRecvDone, "recv-done", 1)                                 \
   V(kReduce, "reduce", kHloOpcodeIsVariadic)                   \
   V(kReducePrecision, "reduce-precision", 1)                   \
-  V(kReduceWindow, "reduce-window", 2)                         \
+  V(kReduceWindow, "reduce-window", kHloOpcodeIsVariadic)      \
   V(kRemainder, "remainder", 2)                                \
   V(kReplicaId, "replica-id", 0)                               \
   V(kReshape, "reshape", 1)                                    \
diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc
index cceb60a70e9..95bb81c60f6 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc
@@ -65,6 +65,7 @@ TEST(HloOpcodeTest, OpcodeProperties) {
       case HloOpcode::kRng:
       case HloOpcode::kSort:
       case HloOpcode::kTuple:
+      case HloOpcode::kReduceWindow:
        EXPECT_TRUE(HloOpcodeIsVariadic(opcode));
        break;
      default:
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index a96c9c34260..43e3ea15b5f 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -2084,7 +2084,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
                            arg_shapes.size());
   }
   int64 num_reduced_args = arg_shapes.size() / 2;
-
   auto reduced_args = arg_shapes.subspan(0, num_reduced_args);
   // Check that all of the reduced tensors have the same dimensions. The element
   // types may be different.
@@ -2097,7 +2096,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
           ShapeUtil::HumanString(*reduced_args[i]));
     }
   }
-
   // Check that the dimensions to reduce are in-bounds for the given shape.
   // We've already verified all reduced tensors have the same dimensions, so it
   // doesn't matter which one we choose.
@@ -2156,6 +2154,43 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
   return InferReduceWindowShape(operand_shape, init_value_shape, window);
 }
 
+/* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
+    absl::Span<const Shape* const> operands,
+    absl::Span<const Shape* const> init_values,
+    const Window& window, const ProgramShape& to_apply_shape) {
+  auto number_of_input = operands.size();
+  // Check that all of the reduced tensors have the same dimensions. The element
+  // types may be different.
+  for (int64 i = 1; i < number_of_input; ++i) {
+    if (!ShapeUtil::SameDimensions(*operands[0], *operands[i])) {
+      return InvalidArgument(
+          "All reduced tensors must have the same dimension. Tensor 0 has "
+          "shape %s, Tensor %d has shape %s",
+          ShapeUtil::HumanString(*operands[0]), i,
+          ShapeUtil::HumanString(*operands[i]));
+    }
+  }
+  std::vector<PrimitiveType> operand_element_type_vec;
+  for (const Shape* s : operands) {
+    operand_element_type_vec.push_back(s->element_type());
+  }
+  TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_values,
+                                        operand_element_type_vec,
+                                        /*inputs=*/number_of_input));
+  std::vector<Shape> output_shape_vec;
+  for (int i = 0; i < operands.size(); ++i) {
+    TF_ASSIGN_OR_RETURN(
+        auto cur_output_shape,
+        InferReduceWindowShape(*operands[i], *init_values[i], window));
+    output_shape_vec.push_back(cur_output_shape);
+  }
+  if (ShapeUtil::IsScalar(to_apply_shape.result())) {
+    CHECK_EQ(output_shape_vec.size(), 1);
+    return output_shape_vec[0];
+  } else {
+    return ShapeUtil::MakeTupleShape(output_shape_vec);
+  }
+}
+
 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
     const Shape& operand_shape, const Shape& init_value_shape,
     const Window& window) {
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index f03e4e5fa98..eb969873fd0 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -164,10 +164,16 @@ class ShapeInference {
   static StatusOr<Shape> InferReduceWindowShape(
       const Shape& operand_shape, const Shape& init_value,
       const Window& window, const ProgramShape& to_apply_shape);
-
   static StatusOr<Shape> InferReduceWindowShape(const Shape& operand_shape,
                                                 const Shape& init_value,
                                                 const Window& window);
+  static StatusOr<Shape> InferReduceWindowShape(
+      absl::Span<const Shape* const> operands,
+      absl::Span<const Shape* const> init_values,
+      const Window& window, const ProgramShape& to_apply_shape);
+
+  static StatusOr<Shape> InferReduceWindowShape(
+      absl::Span<const Shape* const> operands,
+      absl::Span<const Shape* const> init_values, const Window& window);
 
   // Infers the shape produced by scattering the given source shape to the
   // selected indices of each window on the operand shape.
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 00ecb254a17..73bbe5cb3bd 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
 
#include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -912,6 +913,32 @@ TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) { inferred_status.ValueOrDie())); } +TEST_F(ReduceShapeInferenceTest, ReduceWindowMultiOutput) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1}); + std::vector args = {&f32_arg_shape, &s32_arg_shape}; + std::vector inits = {&f32_, &s32_}; + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + std::vector window_dimensions = {1, 2, 4}; + std::vector window_strides = {1, 1, 1}; + std::vector> padding_values = + MakePadding(AsInt64Slice(f32_arg_shape.dimensions()), window_dimensions, + window_strides, Padding::kValid); + TF_ASSERT_OK_AND_ASSIGN( + Window window, + ShapeInference::InferWindowFromDimensions( + window_dimensions, window_strides, padding_values, {}, {})); + auto inferred_status = ShapeInference::InferReduceWindowShape( + absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply); + VLOG(2) << inferred_status.ValueOrDie().ToString() << "\n"; + EXPECT_IS_OK(inferred_status.status()); + EXPECT_TRUE(ShapeUtil::Equal( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {5, 2, 0}), + ShapeUtil::MakeShape(S32, {5, 2, 0})}), + inferred_status.ValueOrDie())); +} + TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) { Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); @@ -948,6 +975,29 @@ TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) { HasSubstr("must have at least 2 arguments, has 0")); } +TEST_F(ReduceShapeInferenceTest, ErrorBadReduceWindowInput) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1}); + std::vector args = {&f32_arg_shape, &s32_arg_shape}; + std::vector inits = {&f32_, &s32_}; + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {f32_, f32_, f32_, f32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + std::vector window_dimensions = {1, 2, 4}; + std::vector window_strides = {1, 1, 1}; + std::vector> padding_values = + MakePadding(AsInt64Slice(f32_arg_shape.dimensions()), window_dimensions, + window_strides, Padding::kValid); + TF_ASSERT_OK_AND_ASSIGN( + Window window, + ShapeInference::InferWindowFromDimensions( + window_dimensions, window_strides, padding_values, {}, {})); + auto inferred_status = ShapeInference::InferReduceWindowShape( + absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply); + EXPECT_FALSE(inferred_status.status().ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("f32[] vs s32[]")); +} + TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) { Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});