diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 68221c036b9..c4266f95fcc 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -4341,5 +4341,20 @@ TEST_F(HloEvaluatorTest, IsFiniteBf16) { ::testing::ElementsAre(false, true, false, true, false, false)); } +// Check that evaluating `f32[, 0] iota` doesn't oom (it's an empty +// array!). +TEST_F(HloEvaluatorTest, ZeroSizedIotaWithHugeDimension) { + constexpr absl::string_view hlo_text = R"( + HloModule test + ENTRY t { + ROOT i = f32[1000000000000, 0] iota(), iota_dimension=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN( + Literal actual_literal, + HloEvaluator().Evaluate(*m_->entry_computation(), {})); + EXPECT_THAT(actual_literal.data(), ::testing::IsEmpty()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index a2afb0c59eb..c3b5838cf0a 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -2505,32 +2505,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::is_floating_point::value>::type* = nullptr> Status HandleIota(HloInstruction* instruction) { auto* iota = Cast(instruction); - const int64 iota_size = iota->shape().dimensions(iota->iota_dimension()); - // Avoid using std::vector since std::vector does not convert to - // absl::Span. - absl::InlinedVector data(iota_size); - // We don't use std::iota for two reasons: - // - // (1) std:iota does not support bfloat16 and float16. - // - // (2) std::iota saturates for floating point types when the value is not - // representable, but the definition of HLO iota is the value as a - // 64-bit integer cast to the native type. - for (int64 i = 0; i < iota_size; ++i) { - // static_cast is required for Eigen::half (F16). - data[i] = static_cast(i); - } - auto result = LiteralUtil::CreateR1(data); - - if (iota->shape().rank() > 1) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[iota], - result.Broadcast(iota->shape(), {iota->iota_dimension()})); - } else { - TF_RET_CHECK(iota->shape().rank() == 1); - parent_->evaluated_[iota] = std::move(result); - } + Literal result(iota->shape()); + ShapeUtil::ForEachIndex(iota->shape(), [&](absl::Span idx) { + result.Set(idx, static_cast(idx[iota->iota_dimension()])); + return true; + }); + parent_->evaluated_[iota] = std::move(result); return Status::OK(); } template <