diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index cbbad741ce3..73c37d6b2f3 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -2104,6 +2104,32 @@ MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr, root_piece_->set_subshape(shape_.get()); } +MutableBorrowingLiteral::MutableBorrowingLiteral(absl::Span src_buf_ptrs, + const Shape& shape) + : MutableLiteralBase() { + shape_ = absl::make_unique(shape); + if (!shape_->IsTuple()) { + CHECK_EQ(src_buf_ptrs.size(), 1); + root_piece_ = new Piece(); + root_piece_->set_buffer(const_cast(src_buf_ptrs[0])); + root_piece_->set_subshape(shape_.get()); + } else { + CHECK(!ShapeUtil::IsNestedTuple(*shape_)); + CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + for (int i = 0; i < src_buf_ptrs.size(); ++i) { + Piece child_piece; + const auto& src_shape = shape_->tuple_shapes(i); + CHECK(src_shape.IsArray()); + child_piece.set_subshape(&src_shape); + child_piece.set_buffer(src_buf_ptrs[i]); + root_piece_->emplace_back(std::move(child_piece)); + } + } +} + MutableBorrowingLiteral::~MutableBorrowingLiteral() { if (root_piece_ != nullptr) { delete root_piece_; diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index 1553d042e80..a2be92fbf5b 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -776,6 +776,10 @@ class MutableBorrowingLiteral : public MutableLiteralBase { const ShapeIndex& view_root); MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape); + // Create a literal from a list of buffers and a shape. + // Returns a tuple literal if `shape` is a tuple type. + MutableBorrowingLiteral(absl::Span src_buf_ptrs, const Shape& shape); + private: // Recursively copies the subtree from the `src_piece` at the given child // index to the `dest_piece`. For buffers only the pointers are copied, but