diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index 4147436330f..c5aa627a02c 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -610,6 +610,11 @@ class MutableLiteralBase : public LiteralBase {
   // Unhide const method from parent class.
   using LiteralBase::untyped_data;
 
+  template <typename NativeT>
+  void MutableEachCell(
+      std::function<NativeT(absl::Span<const int64> indices, NativeT value)>
+          per_cell);
+
   // Copy values from 'src_literal' rooted at 'src_shape_index' into this
   // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
   // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
@@ -989,6 +994,24 @@ void LiteralBase::EachCell(
   } while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices)));
 }
 
+template <typename NativeT>
+void MutableLiteralBase::MutableEachCell(
+    std::function<NativeT(absl::Span<const int64> indices, NativeT value)>
+        per_cell) {
+  if (ShapeUtil::IsZeroElementArray(shape())) {
+    return;
+  }
+  std::vector<int64> indices(shape().rank(), 0);
+
+  Shape shape_dynamic = shape();
+  for (int64 i = 0; i < shape_dynamic.rank(); ++i) {
+    shape_dynamic.set_dimensions(i, GetDynamicSize(i));
+  }
+  do {
+    Set<NativeT>(indices, per_cell(indices, Get<NativeT>(indices)));
+  } while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices)));
+}
+
 template <typename NativeT>
 inline void MutableLiteralBase::PopulateR1(absl::Span<const NativeT> values) {
   CHECK(shape().IsArray());