[XLA] Rollforward of Rolledback CL after fixing mistakes.
*** Reason for rollback of the rolled back CL***
Fix the mistake in the rolled-back CL, by changing a {0} subscript to {} in CopyElementFrom invocation.
PiperOrigin-RevId: 340492736
Change-Id: Idd4c8ab143bb10c6c9bbe876d7d4bb2747599553
			
			
This commit is contained in:
		
							parent
							
								
									cbc7f31b7d
								
							
						
					
					
						commit
						d7a91161c0
					
				| @ -2486,6 +2486,23 @@ Status HloEvaluator::HandleReduce(HloInstruction* instr) { | |||||||
|   return Status::OK(); |   return Status::OK(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | Status HloEvaluator::HandleReduceWindow(HloInstruction* hlo) { | ||||||
|  |   // Here we delegate the handling to the typed visitor class, instantiated by
 | ||||||
|  |   // using the type of the first input of ReduceWindow. The support for the
 | ||||||
|  |   // variadic case inside the typed_visitor is made to not use the template
 | ||||||
|  |   // parameter so it doesn't really matter which type is used to instantiate it
 | ||||||
|  |   // here. We choose not to move the implementation for handle ReduceWindow
 | ||||||
|  |   // from the typed visitor to here because we need to reuse the
 | ||||||
|  |   // IterateThroughWindow method, which is defined and only avaiable inside the
 | ||||||
|  |   // typed visitor.
 | ||||||
|  |   if (hlo->shape().IsTuple()) { | ||||||
|  |     return hlo->Visit( | ||||||
|  |         typed_visitors_[hlo->shape().tuple_shapes(0).element_type()].get()); | ||||||
|  |   } else { | ||||||
|  |     return DefaultAction(hlo); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
| Status HloEvaluator::HandleCustomCall(HloInstruction* custom_call) { | Status HloEvaluator::HandleCustomCall(HloInstruction* custom_call) { | ||||||
|   if (!custom_call_handler_) { |   if (!custom_call_handler_) { | ||||||
|     // No handler is registered; this means custom-calls are not allowed.
 |     // No handler is registered; this means custom-calls are not allowed.
 | ||||||
|  | |||||||
| @ -260,6 +260,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { | |||||||
| 
 | 
 | ||||||
|   Status HandleReduce(HloInstruction* reduce) override; |   Status HandleReduce(HloInstruction* reduce) override; | ||||||
| 
 | 
 | ||||||
|  |   Status HandleReduceWindow(HloInstruction* hlo) override; | ||||||
|  | 
 | ||||||
|   Status HandleCustomCall(HloInstruction* custom_call) override; |   Status HandleCustomCall(HloInstruction* custom_call) override; | ||||||
| 
 | 
 | ||||||
|   // Unsupported HLOs, note some of them (such as BatchNorm*) are typically
 |   // Unsupported HLOs, note some of them (such as BatchNorm*) are typically
 | ||||||
|  | |||||||
| @ -2861,6 +2861,109 @@ TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd6D) { | |||||||
|   EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result)); |   EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | TEST_P(HloEvaluatorBf16Test, Min3In5Stride2Tuple) { | ||||||
|  |   HloComputation::Builder builder("main"); | ||||||
|  |   auto input1 = builder.AddInstruction(HloInstruction::CreateConstant( | ||||||
|  |       LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}))); | ||||||
|  |   auto input2 = builder.AddInstruction(HloInstruction::CreateConstant( | ||||||
|  |       LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}))); | ||||||
|  |   HloComputation::Builder bcompute("ComputeFunction"); | ||||||
|  |   auto shape1 = ShapeUtil::MakeShape(F32, {}); | ||||||
|  |   auto shape2 = ShapeUtil::MakeShape(F32, {}); | ||||||
|  |   auto p2 = | ||||||
|  |       bcompute.AddInstruction(HloInstruction::CreateParameter(0, shape1, "x0")); | ||||||
|  |   auto p3 = | ||||||
|  |       bcompute.AddInstruction(HloInstruction::CreateParameter(1, shape2, "x1")); | ||||||
|  |   auto p4 = | ||||||
|  |       bcompute.AddInstruction(HloInstruction::CreateParameter(2, shape1, "y0")); | ||||||
|  |   auto p5 = | ||||||
|  |       bcompute.AddInstruction(HloInstruction::CreateParameter(3, shape2, "y1")); | ||||||
|  |   std::vector<HloInstruction*> compute_vec = { | ||||||
|  |       bcompute.AddInstruction( | ||||||
|  |           HloInstruction::CreateBinary(shape1, HloOpcode::kMinimum, p2, p4)), | ||||||
|  |       bcompute.AddInstruction( | ||||||
|  |           HloInstruction::CreateBinary(shape2, HloOpcode::kMinimum, p3, p5))}; | ||||||
|  |   bcompute.AddInstruction(HloInstruction::CreateTuple(compute_vec)); | ||||||
|  |   auto compute_tuple = m_->AddEmbeddedComputation(bcompute.Build()); | ||||||
|  |   std::vector<HloInstruction*> input_vec = {input1, input2}; | ||||||
|  |   auto init1 = builder.AddInstruction( | ||||||
|  |       HloInstruction::CreateConstant(LiteralUtil::MaxValue(F32))); | ||||||
|  |   auto init2 = builder.AddInstruction( | ||||||
|  |       HloInstruction::CreateConstant(LiteralUtil::MaxValue(F32))); | ||||||
|  |   std::vector<HloInstruction*> init_vec = {init1, init2}; | ||||||
|  |   auto padding = std::pair<int64, int64>(0, 0); | ||||||
|  |   TF_ASSERT_OK_AND_ASSIGN(auto window, | ||||||
|  |                           ShapeInference::InferWindowFromDimensions( | ||||||
|  |                               {3}, {2}, absl::MakeSpan(&padding, 1), | ||||||
|  |                               /*lhs_dilation=*/{}, | ||||||
|  |                               /*rhs_dilation=*/{})); | ||||||
|  |   std::vector<const Shape*> input_shapes = {&input1->shape(), &input2->shape()}; | ||||||
|  |   std::vector<const Shape*> init_shapes = {&init1->shape(), &init2->shape()}; | ||||||
|  |   TF_ASSERT_OK_AND_ASSIGN(Shape shape, | ||||||
|  |                           ShapeInference::InferReduceWindowShape( | ||||||
|  |                               input_shapes, init_shapes, window, | ||||||
|  |                               compute_tuple->ComputeProgramShape())); | ||||||
|  |   builder.AddInstruction(HloInstruction::CreateReduceWindow( | ||||||
|  |       shape, input_vec, init_vec, window, compute_tuple)); | ||||||
|  |   auto r1 = LiteralUtil::CreateR1<float>({100, 1}); | ||||||
|  |   auto expected = LiteralUtil::MakeTuple({&r1, &r1}); | ||||||
|  |   m_->AddEntryComputation(builder.Build()); | ||||||
|  |   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); | ||||||
|  |   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST_P(HloEvaluatorBf16Test, Min3In5Stride2TupleDiffInput) { | ||||||
|  |   HloComputation::Builder builder("main"); | ||||||
|  |   auto input1 = builder.AddInstruction(HloInstruction::CreateConstant( | ||||||
|  |       LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}))); | ||||||
|  |   auto input2 = builder.AddInstruction(HloInstruction::CreateConstant( | ||||||
|  |       LiteralUtil::CreateR1<int>({15, 28, 300, 107, 12}))); | ||||||
|  |   HloComputation::Builder bcompute("ComputeFunction"); | ||||||
|  |   auto shape1 = ShapeUtil::MakeShape(F32, {}); | ||||||
|  |   auto shape2 = ShapeUtil::MakeShape(S32, {}); | ||||||
|  |   auto p2 = | ||||||
|  |       bcompute.AddInstruction(HloInstruction::CreateParameter(0, shape1, "x0")); | ||||||
|  |   auto p3 = | ||||||
|  |       bcompute.AddInstruction(HloInstruction::CreateParameter(1, shape2, "x1")); | ||||||
|  |   auto p4 = | ||||||
|  |       bcompute.AddInstruction(HloInstruction::CreateParameter(2, shape1, "y0")); | ||||||
|  |   auto p5 = | ||||||
|  |       bcompute.AddInstruction(HloInstruction::CreateParameter(3, shape2, "y1")); | ||||||
|  |   std::vector<HloInstruction*> compute_vec = { | ||||||
|  |       bcompute.AddInstruction( | ||||||
|  |           HloInstruction::CreateBinary(shape1, HloOpcode::kMinimum, p2, p4)), | ||||||
|  |       bcompute.AddInstruction( | ||||||
|  |           HloInstruction::CreateBinary(shape2, HloOpcode::kMinimum, p3, p5))}; | ||||||
|  |   bcompute.AddInstruction(HloInstruction::CreateTuple(compute_vec)); | ||||||
|  |   auto compute_tuple = m_->AddEmbeddedComputation(bcompute.Build()); | ||||||
|  |   std::vector<HloInstruction*> input_vec = {input1, input2}; | ||||||
|  |   auto init1 = builder.AddInstruction( | ||||||
|  |       HloInstruction::CreateConstant(LiteralUtil::MaxValue(F32))); | ||||||
|  |   auto init2 = builder.AddInstruction( | ||||||
|  |       HloInstruction::CreateConstant(LiteralUtil::MaxValue(S32))); | ||||||
|  |   std::vector<HloInstruction*> init_vec = {init1, init2}; | ||||||
|  |   auto padding = std::pair<int64, int64>(0, 0); | ||||||
|  |   TF_ASSERT_OK_AND_ASSIGN(auto window, | ||||||
|  |                           ShapeInference::InferWindowFromDimensions( | ||||||
|  |                               {3}, {2}, absl::MakeSpan(&padding, 1), | ||||||
|  |                               /*lhs_dilation=*/{}, | ||||||
|  |                               /*rhs_dilation=*/{})); | ||||||
|  |   std::vector<const Shape*> input_shapes = {&input1->shape(), &input2->shape()}; | ||||||
|  |   std::vector<const Shape*> init_shapes = {&init1->shape(), &init2->shape()}; | ||||||
|  |   TF_ASSERT_OK_AND_ASSIGN(Shape shape, | ||||||
|  |                           ShapeInference::InferReduceWindowShape( | ||||||
|  |                               input_shapes, init_shapes, window, | ||||||
|  |                               compute_tuple->ComputeProgramShape())); | ||||||
|  |   builder.AddInstruction(HloInstruction::CreateReduceWindow( | ||||||
|  |       shape, input_vec, init_vec, window, compute_tuple)); | ||||||
|  |   auto r1 = LiteralUtil::CreateR1<float>({100, 1}); | ||||||
|  |   auto r2 = LiteralUtil::CreateR1<int>({15, 12}); | ||||||
|  |   auto expected = LiteralUtil::MakeTuple({&r1, &r2}); | ||||||
|  |   m_->AddEntryComputation(builder.Build()); | ||||||
|  |   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); | ||||||
|  |   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| TEST_P(HloEvaluatorBf16Test, StridedSlice) { | TEST_P(HloEvaluatorBf16Test, StridedSlice) { | ||||||
|   HloComputation::Builder b(TestName()); |   HloComputation::Builder b(TestName()); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -26,6 +26,7 @@ limitations under the License. | |||||||
| #include "absl/memory/memory.h" | #include "absl/memory/memory.h" | ||||||
| #include "absl/meta/type_traits.h" | #include "absl/meta/type_traits.h" | ||||||
| #include "absl/types/optional.h" | #include "absl/types/optional.h" | ||||||
|  | #include "absl/types/span.h" | ||||||
| #include "tensorflow/compiler/xla/array2d.h" | #include "tensorflow/compiler/xla/array2d.h" | ||||||
| #include "tensorflow/compiler/xla/literal.h" | #include "tensorflow/compiler/xla/literal.h" | ||||||
| #include "tensorflow/compiler/xla/literal_util.h" | #include "tensorflow/compiler/xla/literal_util.h" | ||||||
| @ -664,6 +665,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { | |||||||
|             typename std::enable_if<std::is_integral<NativeT>::value>::type* = |             typename std::enable_if<std::is_integral<NativeT>::value>::type* = | ||||||
|                 nullptr> |                 nullptr> | ||||||
|   Status HandleMinimum(HloInstruction* minimum) { |   Status HandleMinimum(HloInstruction* minimum) { | ||||||
|  |     VLOG(2) << "Evaluating minimum\n"; | ||||||
|     TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum], |     TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum], | ||||||
|                         ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, |                         ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, | ||||||
|                                                         ElementwiseT rhs_el) { |                                                         ElementwiseT rhs_el) { | ||||||
| @ -1932,18 +1934,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { | |||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   Status HandleReduceWindow(HloInstruction* reduce_window) override { |   Status HandleReduceWindow(HloInstruction* reduce_window) override { | ||||||
|     if (reduce_window->shape().IsTuple()) { |     auto* reduce_window_instr = Cast<HloReduceWindowInstruction>(reduce_window); | ||||||
|       return Status(tensorflow::error::UNIMPLEMENTED, |  | ||||||
|                     "Variadic reduce window op is not yet fully supported."); |  | ||||||
|     } |  | ||||||
|     auto operand = reduce_window->operand(0); |  | ||||||
|     const Window& window = reduce_window->window(); |     const Window& window = reduce_window->window(); | ||||||
|     HloComputation* function = reduce_window->to_apply(); |     HloComputation* function = reduce_window->to_apply(); | ||||||
|     TF_ASSIGN_OR_RETURN( |     TF_ASSIGN_OR_RETURN( | ||||||
|         auto inferred_return_shape, |         auto inferred_return_shape, | ||||||
|         ShapeInference::InferReduceWindowShape( |         ShapeInference::InferReduceWindowShape( | ||||||
|             /*operand_shape=*/reduce_window->operand(0)->shape(), |             reduce_window_instr->input_array_shapes(), | ||||||
|             /*init_value=*/reduce_window->operand(1)->shape(), window, |             reduce_window_instr->init_value_shapes(), window, | ||||||
|             /*to_apply_shape=*/function->ComputeProgramShape())); |             /*to_apply_shape=*/function->ComputeProgramShape())); | ||||||
|     TF_RET_CHECK( |     TF_RET_CHECK( | ||||||
|         ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape)) |         ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape)) | ||||||
| @ -1952,62 +1950,101 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { | |||||||
|         << " but is inferred to be: " |         << " but is inferred to be: " | ||||||
|         << ShapeUtil::HumanStringWithLayout(inferred_return_shape); |         << ShapeUtil::HumanStringWithLayout(inferred_return_shape); | ||||||
| 
 | 
 | ||||||
|     const Literal& operand_literal = |     absl::InlinedVector<const Literal*, 2> input_literal_vec, init_literal_vec; | ||||||
|         parent_->GetEvaluatedLiteralFor(reduce_window->operand(0)); |     auto input_arrays = reduce_window_instr->input_arrays(); | ||||||
|     VLOG(3) << "HandleReduceWindow arg_literal: " << operand_literal.ToString(); |     auto init_values = reduce_window_instr->init_values(); | ||||||
|     const Literal& init_literal = |     int64 num_args = input_arrays.size(); | ||||||
|         parent_->GetEvaluatedLiteralFor(reduce_window->operand(1)); |     for (int i = 0; i < num_args; ++i) { | ||||||
|     VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString(); |       const Literal& input_literal = | ||||||
|     TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); |           parent_->GetEvaluatedLiteralFor(input_arrays[i]); | ||||||
|     auto init_scalar = init_literal.Get<ReturnT>({}); |       VLOG(3) << "HandleReduceWindow arg_literal: " << input_literal.ToString(); | ||||||
| 
 |       input_literal_vec.push_back(&input_literal); | ||||||
|  |       const Literal& init_literal = | ||||||
|  |           parent_->GetEvaluatedLiteralFor(init_values[i]); | ||||||
|  |       VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString(); | ||||||
|  |       TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); | ||||||
|  |       init_literal_vec.push_back(&init_literal); | ||||||
|  |     } | ||||||
|     // Creates a Shape object from window, for iteration below.
 |     // Creates a Shape object from window, for iteration below.
 | ||||||
|     std::vector<int64> window_dimension_sizes; |     absl::InlinedVector<int64, 2> window_dimension_sizes; | ||||||
|     for (const auto& window_dimension : window.dimensions()) { |     for (const auto& window_dimension : window.dimensions()) { | ||||||
|       window_dimension_sizes.push_back(window_dimension.size()); |       window_dimension_sizes.push_back(window_dimension.size()); | ||||||
|     } |     } | ||||||
|     const Shape window_shape = ShapeUtil::MakeShape( |     const Shape window_shape = ShapeUtil::MakeShape( | ||||||
|         operand->shape().element_type(), window_dimension_sizes); |         input_arrays[0]->shape().element_type(), window_dimension_sizes); | ||||||
| 
 |  | ||||||
|     DimensionVector window_index(window.dimensions_size()); |  | ||||||
|     DimensionVector operand_index(operand_literal.shape().rank()); |  | ||||||
| 
 | 
 | ||||||
|     HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); |     HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); | ||||||
|     Literal result(reduce_window->shape()); |  | ||||||
|     // For each resulting dimension, calculate and assign computed value.
 |     // For each resulting dimension, calculate and assign computed value.
 | ||||||
|     TF_RETURN_IF_ERROR( |     auto evaluate_impl = | ||||||
|         result.Populate<ReturnT>([&](absl::Span<const int64> output_index) { |         [&](absl::Span<const int64> output_index) -> std::vector<Literal> { | ||||||
|           ReturnT result_val = init_scalar; |       std::vector<Literal> computed_result; | ||||||
| 
 |       computed_result.reserve(init_literal_vec.size()); | ||||||
|           std::fill(window_index.begin(), window_index.end(), 0); |       for (const auto* init : init_literal_vec) { | ||||||
|           std::fill(operand_index.begin(), operand_index.end(), 0); |         computed_result.push_back(init->Clone()); | ||||||
| 
 |       } | ||||||
|           IterateThroughWindow( |       IterateThroughWindow( | ||||||
|               window_shape, window, operand_literal.shape(), output_index, |           window_shape, window, input_literal_vec[0]->shape(), output_index, | ||||||
|               [&](const std::vector<int64>& operand_index) { |           [&](absl::Span<const int64> operand_index) -> void { | ||||||
|                 auto curr_val = operand_literal.Get<ReturnT>(operand_index); |             absl::InlinedVector<const Literal*, 2> args; | ||||||
| 
 |             for (auto& curr_result_val : computed_result) { | ||||||
|                 // Evaluate computation with specified literal operands.
 |               VLOG(2) << "Pushing:" << curr_result_val.ToString() << "\n"; | ||||||
|                 const auto curr_val_literal = |               args.push_back(&curr_result_val); | ||||||
|                     LiteralUtil::CreateR0<ReturnT>(curr_val); |             } | ||||||
|                 const auto result_val_literal = |             absl::InlinedVector<Literal, 2> curr_val_literal_vec( | ||||||
|                     LiteralUtil::CreateR0<ReturnT>(result_val); |                 input_literal_vec.size()); | ||||||
|                 Literal computed_result = |             for (const auto* input_literal : input_literal_vec) { | ||||||
|                     embedded_evaluator |               // Evaluate computation with specified literal operands.
 | ||||||
|                         .Evaluate(*function, |               curr_val_literal_vec.push_back(Literal(ShapeUtil::MakeShape( | ||||||
|                                   {&result_val_literal, &curr_val_literal}) |                   input_literal->shape().element_type(), {}))); | ||||||
|                         .ConsumeValueOrDie(); |               TF_CHECK_OK(curr_val_literal_vec.back().CopyElementFrom( | ||||||
| 
 |                   *input_literal, operand_index, {})); | ||||||
|                 // Clear visit states so that the we can use the evaluate again
 |               VLOG(2) << "Pushing:" << curr_val_literal_vec.back().ToString() | ||||||
|                 // on the same computation.
 |                       << "\n"; | ||||||
|                 embedded_evaluator.ResetVisitStates(); |               args.push_back(&curr_val_literal_vec.back()); | ||||||
| 
 |             } | ||||||
|                 result_val = computed_result.Get<ReturnT>({}); |             computed_result[0] = embedded_evaluator.Evaluate(*function, args) | ||||||
|               }); |                                      .ConsumeValueOrDie(); | ||||||
| 
 |             VLOG(2) << "Computed result:" << computed_result[0].ToString() | ||||||
|           return result_val; |                     << "\n"; | ||||||
|         })); |             // Clear visit states so that the we can use the evaluate again
 | ||||||
| 
 |             // on the same computation.
 | ||||||
|  |             embedded_evaluator.ResetVisitStates(); | ||||||
|  |             if (inferred_return_shape.IsTuple()) { | ||||||
|  |               computed_result = computed_result[0].DecomposeTuple(); | ||||||
|  |             } | ||||||
|  |           }); | ||||||
|  |       VLOG(2) << "Final result size:" << computed_result.size() << "\n"; | ||||||
|  |       for (const auto& res : computed_result) { | ||||||
|  |         VLOG(2) << res.ToString() << "\n"; | ||||||
|  |       } | ||||||
|  |       return computed_result; | ||||||
|  |     }; | ||||||
|  |     Literal result(inferred_return_shape); | ||||||
|  |     if (inferred_return_shape.IsTuple()) { | ||||||
|  |       absl::InlinedVector<Literal, 1> results(num_args); | ||||||
|  |       for (int64 i = 0; i < num_args; ++i) { | ||||||
|  |         results[i] = Literal(inferred_return_shape.tuple_shapes(i)); | ||||||
|  |       } | ||||||
|  |       TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( | ||||||
|  |           inferred_return_shape.tuple_shapes(0), | ||||||
|  |           [&](absl::Span<const int64> output_index) -> StatusOr<bool> { | ||||||
|  |             std::vector<Literal> computed_result_vec = | ||||||
|  |                 evaluate_impl(output_index); | ||||||
|  |             for (int i = 0; i < computed_result_vec.size(); ++i) { | ||||||
|  |               TF_RETURN_IF_ERROR(results[i].CopyElementFrom( | ||||||
|  |                   computed_result_vec[i], {}, output_index)); | ||||||
|  |             } | ||||||
|  |             return true; | ||||||
|  |           })); | ||||||
|  |       result = Literal::MoveIntoTuple(absl::MakeSpan(results)); | ||||||
|  |       VLOG(2) << "Final result is:" << result.ToString() << "\n"; | ||||||
|  |     } else { | ||||||
|  |       TF_RETURN_IF_ERROR( | ||||||
|  |           result.Populate<ReturnT>([&](absl::Span<const int64> output_index) { | ||||||
|  |             return evaluate_impl(output_index)[0].template Get<ReturnT>({}); | ||||||
|  |           })); | ||||||
|  |     } | ||||||
|  |     VLOG(2) << "Final result is:" << result.ToString() << "\n"; | ||||||
|     parent_->evaluated_[reduce_window] = std::move(result); |     parent_->evaluated_[reduce_window] = std::move(result); | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
|   } |   } | ||||||
|  | |||||||
| @ -881,12 +881,16 @@ Status ShapeVerifier::HandleMap(HloInstruction* map) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) { | Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) { | ||||||
|  |   VLOG(2) << "Verify reduce window:" << reduce_window->ToString() << "\n"; | ||||||
|  |   auto reduce_window_instr = Cast<HloReduceWindowInstruction>(reduce_window); | ||||||
|  |   auto input_shapes = reduce_window_instr->input_array_shapes(); | ||||||
|  |   VLOG(2) << "reduce window input shape count: " << input_shapes.size() << "\n"; | ||||||
|  |   auto init_shapes = reduce_window_instr->init_value_shapes(); | ||||||
|  |   VLOG(2) << "reduce instruction is :" << reduce_window->ToString() << "\n"; | ||||||
|   TF_RETURN_IF_ERROR(CheckShape( |   TF_RETURN_IF_ERROR(CheckShape( | ||||||
|       reduce_window, |       reduce_window, ShapeInference::InferReduceWindowShape( | ||||||
|       ShapeInference::InferReduceWindowShape( |                          input_shapes, init_shapes, reduce_window->window(), | ||||||
|           reduce_window->operand(0)->shape(), |                          reduce_window->to_apply()->ComputeProgramShape()))); | ||||||
|           reduce_window->operand(1)->shape(), reduce_window->window(), |  | ||||||
|           reduce_window->to_apply()->ComputeProgramShape()))); |  | ||||||
| 
 | 
 | ||||||
|   return allow_mixed_precision_ |   return allow_mixed_precision_ | ||||||
|              ? Status::OK() |              ? Status::OK() | ||||||
|  | |||||||
| @ -2168,8 +2168,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape( | /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape( | ||||||
|     absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values, |     absl::Span<const Shape* const> operands, | ||||||
|     const Window& window, const ProgramShape& to_apply_shape) { |     absl::Span<const Shape* const> init_values, const Window& window, | ||||||
|  |     const ProgramShape& to_apply_shape) { | ||||||
|   auto number_of_input = operands.size(); |   auto number_of_input = operands.size(); | ||||||
|   // Check that all of the reduced tensors have the same dimensions. The element
 |   // Check that all of the reduced tensors have the same dimensions. The element
 | ||||||
|   // types may be different.
 |   // types may be different.
 | ||||||
|  | |||||||
| @ -168,8 +168,9 @@ class ShapeInference { | |||||||
|                                                 const Shape& init_value, |                                                 const Shape& init_value, | ||||||
|                                                 const Window& window); |                                                 const Window& window); | ||||||
|   static StatusOr<Shape> InferReduceWindowShape( |   static StatusOr<Shape> InferReduceWindowShape( | ||||||
|       absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values, |       absl::Span<const Shape* const> operands, | ||||||
|       const Window& window, const ProgramShape& to_apply_shape); |       absl::Span<const Shape* const> init_values, const Window& window, | ||||||
|  |       const ProgramShape& to_apply_shape); | ||||||
| 
 | 
 | ||||||
|   static StatusOr<Shape> InferReduceWindowShape( |   static StatusOr<Shape> InferReduceWindowShape( | ||||||
|       absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values, |       absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values, | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user