[XLA] Don't use temporary memory in evaluator's implementation of iota.

Fixes a fuzztest bug where we try to evaluate `f16[<huge>, 0] iota`.  This is
an empty array, but the old implementation would still try to do a huge
allocation.
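
For intuition, here is a minimal standalone sketch (hypothetical helper, not XLA's ShapeUtil) of the mismatch: an array's element count is the product of all its dimensions, so any zero dimension makes it empty, while the old evaluator sized its temporary buffer from the iota dimension alone.

    #include <cstdint>
    #include <vector>

    // Hypothetical helper; `dims` stands in for the shape f16[10^12, 0].
    int64_t NumElements(const std::vector<int64_t>& dims) {
      int64_t n = 1;
      for (int64_t d : dims) n *= d;  // any zero dimension makes n zero
      return n;
    }

    int main() {
      // NumElements({1000000000000, 0}) == 0: nothing needs to be stored.
      // The old implementation nevertheless allocated
      // dims[iota_dimension] == 10^12 temporary elements.
      return NumElements({1000000000000, 0}) == 0 ? 0 : 1;
    }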

PiperOrigin-RevId: 247943755
Justin Lebar authored 2019-05-13 08:55:12 -07:00; committed by TensorFlower Gardener
parent 19872a69a9, commit 1e854a39ef
2 changed files with 21 additions and 25 deletions

tensorflow/compiler/xla/service/hlo_evaluator_test.cc

@@ -4341,5 +4341,20 @@ TEST_F(HloEvaluatorTest, IsFiniteBf16) {
               ::testing::ElementsAre(false, true, false, true, false, false));
 }
 
+// Check that evaluating `f32[<huge>, 0] iota` doesn't oom (it's an empty
+// array!).
+TEST_F(HloEvaluatorTest, ZeroSizedIotaWithHugeDimension) {
+  constexpr absl::string_view hlo_text = R"(
+  HloModule test
+  ENTRY t {
+    ROOT i = f32[1000000000000, 0] iota(), iota_dimension=0
+  })";
+  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
+  TF_ASSERT_OK_AND_ASSIGN(
+      Literal actual_literal,
+      HloEvaluator().Evaluate(*m_->entry_computation(), {}));
+  EXPECT_THAT(actual_literal.data<float>(), ::testing::IsEmpty());
+}
+
 }  // namespace
 }  // namespace xla
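
For scale: under the old implementation, this test's f32[1000000000000, 0] iota would have allocated a temporary buffer of 10^12 floats, roughly 4 TB, for the iota dimension before broadcasting; with the fix, only the zero-element result literal is created.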

tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h

@@ -2505,32 +2505,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
                 std::is_floating_point<NativeT>::value>::type* = nullptr>
   Status HandleIota(HloInstruction* instruction) {
     auto* iota = Cast<HloIotaInstruction>(instruction);
-    const int64 iota_size = iota->shape().dimensions(iota->iota_dimension());
-    // Avoid using std::vector since std::vector<bool> does not convert to
-    // absl::Span<bool>.
-    absl::InlinedVector<NativeT, 1> data(iota_size);
-    // We don't use std::iota for two reasons:
-    //
-    // (1) std::iota does not support bfloat16 and float16.
-    //
-    // (2) std::iota saturates for floating point types when the value is not
-    // representable, but the definition of HLO iota is the value as a
-    // 64-bit integer cast to the native type.
-    for (int64 i = 0; i < iota_size; ++i) {
-      // static_cast is required for Eigen::half (F16).
-      data[i] = static_cast<NativeT>(i);
-    }
-    auto result = LiteralUtil::CreateR1<NativeT>(data);
-
-    if (iota->shape().rank() > 1) {
-      TF_ASSIGN_OR_RETURN(
-          parent_->evaluated_[iota],
-          result.Broadcast(iota->shape(), {iota->iota_dimension()}));
-    } else {
-      TF_RET_CHECK(iota->shape().rank() == 1);
-      parent_->evaluated_[iota] = std::move(result);
-    }
+    Literal result(iota->shape());
+    ShapeUtil::ForEachIndex(iota->shape(), [&](absl::Span<const int64> idx) {
+      result.Set(idx, static_cast<NativeT>(idx[iota->iota_dimension()]));
+      return true;
+    });
+    parent_->evaluated_[iota] = std::move(result);
     return Status::OK();
   }
 
   template <
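
The replacement relies on ShapeUtil::ForEachIndex visiting every index of the shape, which for a zero-element shape means zero visits and hence no work at all. A rough standalone sketch of that iteration pattern, with hypothetical names and simplified from XLA's actual interface (which also supports early exit via the callback's bool return):

    #include <cstdint>
    #include <functional>
    #include <vector>

    // Hypothetical sketch: call `visit` once per multi-dimensional index
    // of `dims`, in row-major order.
    void ForEachIndexSketch(
        const std::vector<int64_t>& dims,
        const std::function<void(const std::vector<int64_t>&)>& visit) {
      int64_t count = 1;
      for (int64_t d : dims) count *= d;  // zero if any dimension is zero
      std::vector<int64_t> idx(dims.size(), 0);
      for (int64_t i = 0; i < count; ++i) {
        visit(idx);
        // Advance idx like an odometer, last dimension fastest.
        for (int64_t d = static_cast<int64_t>(dims.size()) - 1; d >= 0; --d) {
          if (++idx[d] < dims[d]) break;
          idx[d] = 0;
        }
      }
    }

For dims = {1000000000000, 0} the count is zero, so the callback never runs and the evaluator writes nothing beyond creating the (empty) result literal.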