[XLA] Don't use temporary memory in evaluator's implementation of iota.
Fixes a fuzztest bug where we try to evaluate `f16[<huge>, 0] iota`. This is an empty array, but the old implementation would still try to do a huge allocation.

PiperOrigin-RevId: 247943755
This commit is contained in:
parent
19872a69a9
commit
1e854a39ef
@ -4341,5 +4341,20 @@ TEST_F(HloEvaluatorTest, IsFiniteBf16) {
|
|||||||
::testing::ElementsAre(false, true, false, true, false, false));
|
::testing::ElementsAre(false, true, false, true, false, false));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check that evaluating `f32[<huge>, 0] iota` doesn't oom (it's an empty
|
||||||
|
// array!).
|
||||||
|
TEST_F(HloEvaluatorTest, ZeroSizedIotaWithHugeDimension) {
|
||||||
|
constexpr absl::string_view hlo_text = R"(
|
||||||
|
HloModule test
|
||||||
|
ENTRY t {
|
||||||
|
ROOT i = f32[1000000000000, 0] iota(), iota_dimension=0
|
||||||
|
})";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
Literal actual_literal,
|
||||||
|
HloEvaluator().Evaluate(*m_->entry_computation(), {}));
|
||||||
|
EXPECT_THAT(actual_literal.data<float>(), ::testing::IsEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -2505,32 +2505,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
std::is_floating_point<NativeT>::value>::type* = nullptr>
|
std::is_floating_point<NativeT>::value>::type* = nullptr>
|
||||||
// Evaluates an iota instruction for floating-point element types.
//
// Writes each output element in place via ShapeUtil::ForEachIndex instead of
// materializing a rank-1 temporary and broadcasting it: the old approach
// allocated `dimensions(iota_dimension)` elements of scratch memory even when
// the overall shape was empty (e.g. f16[<huge>, 0]), which could OOM.
// ForEachIndex visits no indices for a zero-element shape, so no scratch
// memory is ever allocated.
//
// We deliberately do not use std::iota for two reasons:
//   (1) std::iota does not support bfloat16 and float16.
//   (2) std::iota saturates for floating-point types when the value is not
//       representable, but HLO iota is defined as the 64-bit integer index
//       cast to the native type.
Status HandleIota(HloInstruction* instruction) {
  auto* iota = Cast<HloIotaInstruction>(instruction);
  Literal result(iota->shape());
  ShapeUtil::ForEachIndex(iota->shape(), [&](absl::Span<const int64> idx) {
    // static_cast is required for Eigen::half (F16).
    result.Set(idx, static_cast<NativeT>(idx[iota->iota_dimension()]));
    return true;
  });
  parent_->evaluated_[iota] = std::move(result);
  return Status::OK();
}
|
||||||
template <
|
template <
|
||||||
|
Loading…
Reference in New Issue
Block a user