diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index 4147436330f..c5aa627a02c 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -610,6 +610,11 @@ class MutableLiteralBase : public LiteralBase { // Unhide const method from parent class. using LiteralBase::untyped_data; + template <typename NativeT> + void MutableEachCell( + std::function<NativeT(absl::Span<const int64> indices, NativeT value)> + per_cell); + // Copy values from 'src_literal' rooted at 'src_shape_index' into this // literal rooted at 'dest_shape_index'. The subshape of this literal rooted // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' @@ -989,6 +994,24 @@ void LiteralBase::EachCell( } while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices))); } +template <typename NativeT> +void MutableLiteralBase::MutableEachCell( + std::function<NativeT(absl::Span<const int64> indices, NativeT value)> + per_cell) { + if (ShapeUtil::IsZeroElementArray(shape())) { + return; + } + std::vector<int64> indices(shape().rank(), 0); + + Shape shape_dynamic = shape(); + for (int64 i = 0; i < shape_dynamic.rank(); ++i) { + shape_dynamic.set_dimensions(i, GetDynamicSize(i)); + } + do { + Set<NativeT>(indices, per_cell(indices, Get<NativeT>(indices))); + } while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices))); +} + template <typename NativeT> inline void MutableLiteralBase::PopulateR1(absl::Span<const NativeT> values) { CHECK(shape().IsArray());