From 1e854a39efee806c836884a62583274f95eb4787 Mon Sep 17 00:00:00 2001
From: Justin Lebar
Date: Mon, 13 May 2019 08:55:12 -0700
Subject: [PATCH] [XLA] Don't use temporary memory in evaluator's implementation of iota.

Fixes a fuzztest bug where we try to evaluate `f16[<huge>, 0] iota`.  This is
an empty array, but the old implementation would still try to do a huge
allocation.

PiperOrigin-RevId: 247943755
---
 .../xla/service/hlo_evaluator_test.cc         | 15 +++++++++
 .../xla/service/hlo_evaluator_typed_visitor.h | 31 ++++---------------
 2 files changed, 21 insertions(+), 25 deletions(-)

diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 68221c036b9..c4266f95fcc 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -4341,5 +4341,20 @@ TEST_F(HloEvaluatorTest, IsFiniteBf16) {
               ::testing::ElementsAre(false, true, false, true, false, false));
 }
 
+// Check that evaluating `f32[<huge>, 0] iota` doesn't oom (it's an empty
+// array!).
+TEST_F(HloEvaluatorTest, ZeroSizedIotaWithHugeDimension) {
+  constexpr absl::string_view hlo_text = R"(
+  HloModule test
+  ENTRY t {
+    ROOT i = f32[1000000000000, 0] iota(), iota_dimension=0
+  })";
+  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
+  TF_ASSERT_OK_AND_ASSIGN(
+      Literal actual_literal,
+      HloEvaluator().Evaluate(*m_->entry_computation(), {}));
+  EXPECT_THAT(actual_literal.data<float>(), ::testing::IsEmpty());
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index a2afb0c59eb..c3b5838cf0a 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -2505,32 +2505,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
                 std::is_floating_point<NativeT>::value>::type* = nullptr>
   Status HandleIota(HloInstruction* instruction) {
     auto* iota = Cast<HloIotaInstruction>(instruction);
-    const int64 iota_size = iota->shape().dimensions(iota->iota_dimension());
-    // Avoid using std::vector since std::vector<bool> does not convert to
-    // absl::Span<const bool>.
-    absl::InlinedVector<NativeT, 1> data(iota_size);
-    // We don't use std::iota for two reasons:
-    //
-    //  (1) std:iota does not support bfloat16 and float16.
-    //
-    //  (2) std::iota saturates for floating point types when the value is not
-    //      representable, but the definition of HLO iota is the value as a
-    //      64-bit integer cast to the native type.
-    for (int64 i = 0; i < iota_size; ++i) {
-      // static_cast is required for Eigen::half (F16).
-      data[i] = static_cast<NativeT>(i);
-    }
-    auto result = LiteralUtil::CreateR1<NativeT>(data);
-
-    if (iota->shape().rank() > 1) {
-      TF_ASSIGN_OR_RETURN(
-          parent_->evaluated_[iota],
-          result.Broadcast(iota->shape(), {iota->iota_dimension()}));
-    } else {
-      TF_RET_CHECK(iota->shape().rank() == 1);
-      parent_->evaluated_[iota] = std::move(result);
-    }
+    Literal result(iota->shape());
+    ShapeUtil::ForEachIndex(iota->shape(), [&](absl::Span<const int64> idx) {
+      result.Set(idx, static_cast<NativeT>(idx[iota->iota_dimension()]));
+      return true;
+    });
+    parent_->evaluated_[iota] = std::move(result);
     return Status::OK();
   }
 
   template <
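
As a standalone illustration of the idea behind the change: the old path
materialized one value per index of the iota dimension and then broadcast,
so a shape like f32[1000000000000, 0] forced a multi-terabyte temporary even
though the output holds zero elements, while the new path only ever writes
the output's own elements. The plain C++ sketch below uses no XLA APIs;
NumElements, the stride computation, and main are made-up stand-ins for the
same row-major indexing idea.

// Standalone sketch; NumElements and the index math are illustrative
// stand-ins, not XLA APIs.
#include <cstdint>
#include <iostream>
#include <vector>

int64_t NumElements(const std::vector<int64_t>& dims) {
  int64_t n = 1;
  for (int64_t d : dims) n *= d;  // any zero dimension gives an empty array
  return n;
}

int main() {
  const std::vector<int64_t> dims = {1000000000000, 0};  // like f32[10^12, 0]
  const int64_t iota_dimension = 0;

  // Old-style approach (left commented out): materialize the iota dimension
  // first, then broadcast. This temporary alone would need about 4 TB here.
  // std::vector<float> iota_values(dims[iota_dimension]);

  // New-style approach: size the result by the full shape's element count and
  // fill it in place. NumElements(dims) is 0, so nothing is allocated and the
  // loop body never runs.
  std::vector<float> result(NumElements(dims));
  const int64_t stride = dims[1];  // row-major stride of dimension 0
  for (int64_t i = 0; i < static_cast<int64_t>(result.size()); ++i) {
    // The coordinate along the iota dimension, recovered from the linear
    // (row-major) index, is the value written at that position.
    result[i] = static_cast<float>((i / stride) % dims[iota_dimension]);
  }
  std::cout << "result has " << result.size() << " elements\n";
  return 0;
}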