Add a method to update a literal using dynamic bounds if necessary.

PiperOrigin-RevId: 358255625
Change-Id: I06518781b9e341e0942cabd6099057b9845a5625
This commit is contained in:
A. Unique TensorFlower 2021-02-18 13:45:18 -08:00 committed by TensorFlower Gardener
parent 8d1b9a9f82
commit f8a5d4e6cb

View File

@ -610,6 +610,11 @@ class MutableLiteralBase : public LiteralBase {
// Unhide const method from parent class.
using LiteralBase::untyped_data;
template <typename NativeT>
void MutableEachCell(
std::function<NativeT(absl::Span<const int64> indices, NativeT value)>
per_cell);
// Copy values from 'src_literal' rooted at 'src_shape_index' into this
// literal rooted at 'dest_shape_index'. The subshape of this literal rooted
// at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
@ -989,6 +994,24 @@ void LiteralBase::EachCell(
} while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices)));
}
template <typename NativeT>
void MutableLiteralBase::MutableEachCell(
std::function<NativeT(absl::Span<const int64> indices, NativeT value)>
per_cell) {
if (ShapeUtil::IsZeroElementArray(shape())) {
return;
}
std::vector<int64> indices(shape().rank(), 0);
Shape shape_dynamic = shape();
for (int64 i = 0; i < shape_dynamic.rank(); ++i) {
shape_dynamic.set_dimensions(i, GetDynamicSize(i));
}
do {
Set<NativeT>(indices, per_cell(indices, Get<NativeT>(indices)));
} while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices)));
}
template <typename NativeT>
inline void MutableLiteralBase::PopulateR1(absl::Span<const NativeT> values) {
CHECK(shape().IsArray());