[XLA] Don't use temporary memory in evaluator's implementation of iota.
Fixes a fuzztest bug where we try to evaluate `f16[<huge>, 0] iota`. This is an empty array, but the old implementation would still try to do a huge allocation. PiperOrigin-RevId: 247943755
This commit is contained in:
parent
19872a69a9
commit
1e854a39ef
@ -4341,5 +4341,20 @@ TEST_F(HloEvaluatorTest, IsFiniteBf16) {
|
||||
::testing::ElementsAre(false, true, false, true, false, false));
|
||||
}
|
||||
|
||||
// Check that evaluating `f32[<huge>, 0] iota` doesn't oom (it's an empty
|
||||
// array!).
|
||||
TEST_F(HloEvaluatorTest, ZeroSizedIotaWithHugeDimension) {
|
||||
constexpr absl::string_view hlo_text = R"(
|
||||
HloModule test
|
||||
ENTRY t {
|
||||
ROOT i = f32[1000000000000, 0] iota(), iota_dimension=0
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
Literal actual_literal,
|
||||
HloEvaluator().Evaluate(*m_->entry_computation(), {}));
|
||||
EXPECT_THAT(actual_literal.data<float>(), ::testing::IsEmpty());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -2505,32 +2505,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
std::is_floating_point<NativeT>::value>::type* = nullptr>
|
||||
Status HandleIota(HloInstruction* instruction) {
|
||||
auto* iota = Cast<HloIotaInstruction>(instruction);
|
||||
const int64 iota_size = iota->shape().dimensions(iota->iota_dimension());
|
||||
// Avoid using std::vector since std::vector<bool> does not convert to
|
||||
// absl::Span<bool>.
|
||||
absl::InlinedVector<NativeT, 1> data(iota_size);
|
||||
// We don't use std::iota for two reasons:
|
||||
//
|
||||
// (1) std:iota does not support bfloat16 and float16.
|
||||
//
|
||||
// (2) std::iota saturates for floating point types when the value is not
|
||||
// representable, but the definition of HLO iota is the value as a
|
||||
// 64-bit integer cast to the native type.
|
||||
for (int64 i = 0; i < iota_size; ++i) {
|
||||
// static_cast is required for Eigen::half (F16).
|
||||
data[i] = static_cast<NativeT>(i);
|
||||
}
|
||||
auto result = LiteralUtil::CreateR1<NativeT>(data);
|
||||
|
||||
if (iota->shape().rank() > 1) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
parent_->evaluated_[iota],
|
||||
result.Broadcast(iota->shape(), {iota->iota_dimension()}));
|
||||
} else {
|
||||
TF_RET_CHECK(iota->shape().rank() == 1);
|
||||
parent_->evaluated_[iota] = std::move(result);
|
||||
}
|
||||
|
||||
Literal result(iota->shape());
|
||||
ShapeUtil::ForEachIndex(iota->shape(), [&](absl::Span<const int64> idx) {
|
||||
result.Set(idx, static_cast<NativeT>(idx[iota->iota_dimension()]));
|
||||
return true;
|
||||
});
|
||||
parent_->evaluated_[iota] = std::move(result);
|
||||
return Status::OK();
|
||||
}
|
||||
template <
|
||||
|
Loading…
Reference in New Issue
Block a user