Add a method to update a literal using dynamic bounds if necessary.
PiperOrigin-RevId: 358255625 Change-Id: I06518781b9e341e0942cabd6099057b9845a5625
This commit is contained in:
parent
8d1b9a9f82
commit
f8a5d4e6cb
@ -610,6 +610,11 @@ class MutableLiteralBase : public LiteralBase {
|
|||||||
// Unhide const method from parent class.
|
// Unhide const method from parent class.
|
||||||
using LiteralBase::untyped_data;
|
using LiteralBase::untyped_data;
|
||||||
|
|
||||||
|
template <typename NativeT>
|
||||||
|
void MutableEachCell(
|
||||||
|
std::function<NativeT(absl::Span<const int64> indices, NativeT value)>
|
||||||
|
per_cell);
|
||||||
|
|
||||||
// Copy values from 'src_literal' rooted at 'src_shape_index' into this
|
// Copy values from 'src_literal' rooted at 'src_shape_index' into this
|
||||||
// literal rooted at 'dest_shape_index'. The subshape of this literal rooted
|
// literal rooted at 'dest_shape_index'. The subshape of this literal rooted
|
||||||
// at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
|
// at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
|
||||||
@ -989,6 +994,24 @@ void LiteralBase::EachCell(
|
|||||||
} while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices)));
|
} while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename NativeT>
|
||||||
|
void MutableLiteralBase::MutableEachCell(
|
||||||
|
std::function<NativeT(absl::Span<const int64> indices, NativeT value)>
|
||||||
|
per_cell) {
|
||||||
|
if (ShapeUtil::IsZeroElementArray(shape())) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
std::vector<int64> indices(shape().rank(), 0);
|
||||||
|
|
||||||
|
Shape shape_dynamic = shape();
|
||||||
|
for (int64 i = 0; i < shape_dynamic.rank(); ++i) {
|
||||||
|
shape_dynamic.set_dimensions(i, GetDynamicSize(i));
|
||||||
|
}
|
||||||
|
do {
|
||||||
|
Set<NativeT>(indices, per_cell(indices, Get<NativeT>(indices)));
|
||||||
|
} while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices)));
|
||||||
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
inline void MutableLiteralBase::PopulateR1(absl::Span<const NativeT> values) {
|
inline void MutableLiteralBase::PopulateR1(absl::Span<const NativeT> values) {
|
||||||
CHECK(shape().IsArray());
|
CHECK(shape().IsArray());
|
||||||
|
Loading…
Reference in New Issue
Block a user