STT-tensorflow/tensorflow/compiler/xla/literal.cc
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/literal.h"
#include <algorithm>
#include <cstring>
#include <functional>
#include <limits>
#include <numeric>
#include <vector>
#include "absl/base/casts.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace {
using absl::StrCat;
constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
// Literals can be used as DMA targets, which can require alignment. We
// force a 64-byte minimum alignment, matching
// tensorflow::Allocator::kAllocatorAlignment.
constexpr int kMinimumAlignment = 64;
// Converts between little and big endian.
//
// Precondition: size % 2 == 0 (elements in the array are 16 bits long)
void ConvertEndianShort(string* bytes) {
CHECK_EQ(bytes->size() % 2, 0);
for (int64 i = 0, end = bytes->size(); i < end; i += 2) {
std::swap((*bytes)[i], (*bytes)[i + 1]);
}
}
void ConvertEndianShort(char* bytes, int64 size) {
CHECK_EQ(size % 2, 0);
for (int64 i = 0; i < size; i += 2) {
std::swap(bytes[i], bytes[i + 1]);
}
}
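// Usage sketch (illustrative; the byte values below are assumptions, not
// taken from this file): swapping a buffer holding two 16-bit elements.
//
//   string bytes("\x12\x34\x56\x78", 4);  // two 16-bit values
//   ConvertEndianShort(&bytes);           // bytes is now "\x34\x12\x78\x56"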
// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
// able to transparently access the raw 16-bit value contained within.
template <typename T>
T GetRawValue(T val) {
return val;
}
uint16 GetRawValue(Eigen::half val) { return val.x; }
bool LiteralProtoHasValues(const LiteralProto& proto) {
return proto.preds_size() || !proto.s8s().empty() || !proto.u8s().empty() ||
proto.s32s_size() || proto.s64s_size() || proto.u32s_size() ||
proto.u64s_size() || proto.f32s_size() || proto.f64s_size() ||
proto.c64s_size() || proto.c128s_size() ||
proto.tuple_literals_size() || !proto.f16s().empty() ||
!proto.bf16s().empty() || !proto.u16s().empty() ||
!proto.s16s().empty();
}
} // namespace
LiteralBase::~LiteralBase() {}
std::ostream& operator<<(std::ostream& out, const Literal& literal) {
out << literal.ToString();
return out;
}
MutableLiteralBase::StrideConfig::StrideConfig(
const Shape& source_shape, const Shape& dest_shape,
absl::Span<const int64> dimensions)
: dimensions(dimensions),
base(dimensions.size(), 0),
step(dimensions.size(), 1) {
if (!dimensions.empty()) {
// Selects the shape with the largest minor dimension as the one upon
// which to run the tight stride loop.
if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >=
dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) {
minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0);
dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension);
} else {
minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0);
source_stride =
IndexUtil::GetDimensionStride(source_shape, minor_dimension);
}
minor_loop_size = dimensions[minor_dimension];
step[minor_dimension] = minor_loop_size;
}
}
Literal::Literal(const Shape& shape)
: Literal(shape, /*allocate_arrays=*/true) {}
void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
if (shape.IsTuple()) {
for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
const Shape& subshape = shape.tuple_shapes(i);
auto child_piece = Piece();
child_piece.set_subshape(&subshape);
SetPiece(subshape, &child_piece, allocate_arrays);
piece->emplace_back(std::move(child_piece));
}
} else if (shape.IsArray()) {
if (allocate_arrays) {
piece->set_buffer(static_cast<char*>(tensorflow::port::AlignedMalloc(
piece->size_bytes(), kMinimumAlignment)));
if (shape.is_dynamic()) {
CHECK_EQ(piece->dynamic_size_buffer(), nullptr);
piece->set_dynamic_size_buffer(
static_cast<int32*>(tensorflow::port::AlignedMalloc(
piece->dynamic_size_buffer_bytes(), kMinimumAlignment)));
}
}
} else {
// If the shape is neither an array nor a tuple, it must be zero-sized;
// otherwise memory would need to be allocated for it.
CHECK_EQ(piece->size_bytes(), 0);
}
}
Literal::Literal(const Shape& shape, bool allocate_arrays)
: MutableLiteralBase() {
shape_ = absl::make_unique<Shape>(shape);
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get());
CHECK(&root_piece_->subshape() == shape_.get());
SetPiece(*shape_, root_piece_, allocate_arrays);
}
Literal::~Literal() {
if (root_piece_ != nullptr) {
DeallocateBuffers();
delete root_piece_;
}
}
void Literal::DeallocateBuffers() {
root_piece_->ForEachMutableSubpiece(
[&](const ShapeIndex& index, Piece* piece) {
if (piece->buffer() != nullptr) {
tensorflow::port::AlignedFree(piece->buffer());
}
if (piece->dynamic_size_buffer() != nullptr) {
tensorflow::port::AlignedFree(piece->dynamic_size_buffer());
}
});
}
Literal::Literal(Literal&& other) : MutableLiteralBase() {
*this = std::move(other);
}
Literal& Literal::operator=(Literal&& other) {
DCHECK(&other.root_piece_->subshape() == other.shape_.get());
using std::swap;
swap(shape_, other.shape_);
swap(root_piece_, other.root_piece_);
DCHECK(&root_piece_->subshape() == shape_.get());
return *this;
}
Literal LiteralBase::CreateFromShape(const Shape& shape) {
Literal literal(shape);
literal.root_piece_->ForEachMutableSubpiece(
[&](const ShapeIndex& index, Piece* piece) {
if (piece->subshape().IsArray()) {
memset(piece->untyped_data(), 0, piece->size_bytes());
}
});
return literal;
}
int32 LiteralBase::GetDynamicSize(int64 dim_index) const {
return GetDynamicSize(dim_index, {});
}
int32 LiteralBase::GetDynamicSize(int64 dim_index,
const ShapeIndex& shape_index) const {
return piece(shape_index).GetDynamicSize(dim_index);
}
absl::optional<int64> LiteralBase::GetFirstInteger() const {
switch (shape().element_type()) {
case U8:
return GetFirstElement<uint8>();
case U16:
return GetFirstElement<uint16>();
case U32:
return GetFirstElement<uint32>();
case U64: {
int64 v = GetFirstElement<uint64>();
if (v < 0) {
return absl::nullopt;
}
return v;
}
case S8:
return GetFirstElement<int8>();
case S16:
return GetFirstElement<int16>();
case S32:
return GetFirstElement<int32>();
case S64:
return GetFirstElement<int64>();
default:
return absl::nullopt;
}
}
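// Sketch of the U64 edge case above (values are illustrative): a uint64
// first element that does not fit in int64 yields absl::nullopt.
//
//   Literal v(ShapeUtil::MakeShape(U64, {}));
//   v.Set<uint64>({}, uint64{1} << 63);        // exceeds INT64_MAX
//   CHECK(!v.GetFirstInteger().has_value());
//   v.Set<uint64>({}, 42);
//   CHECK_EQ(*v.GetFirstInteger(), 42);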
template <typename NativeT>
Status MutableLiteralBase::CopySliceFromInternal(
const LiteralBase& src_literal, absl::Span<const int64> src_base,
absl::Span<const int64> dest_base, absl::Span<const int64> copy_size) {
const int64 src_base_size = src_base.size();
const int64 dest_base_size = dest_base.size();
TF_RET_CHECK(src_literal.shape().rank() == src_base_size);
TF_RET_CHECK(shape().rank() == dest_base_size);
auto linear_index = [](const Shape& shape,
absl::Span<const int64> multi_index) {
return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
};
if (src_literal.shape().rank() == 0 || shape().rank() == 0) {
// If either shape is a scalar, we can call StridedCopy() directly, knowing
// that we will be copying exactly one value.
TF_RET_CHECK(copy_size.empty());
StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
src_literal.data<NativeT>(),
linear_index(src_literal.shape(), src_base), 0, 1);
} else if (!ShapeUtil::IsZeroElementArray(shape()) &&
!ShapeUtil::IsZeroElementArray(src_literal.shape())) {
// Perform the copy only if neither src nor dest has a zero-element
// dimension; otherwise the copy is a no-op.
TF_RET_CHECK(src_base.size() == dest_base.size());
TF_RET_CHECK(src_base.size() == copy_size.size());
// Scan the source from the minor dimension, stepping in copy-size blocks.
// Within the index enumeration functor, do a strided copy that advances the
// source index by one (walking through the minor dimension) and the
// destination index by the proper stride at the matching dimension.
DimensionVector src_indexes(src_base.size(), 0);
DimensionVector dest_indexes(dest_base.size(), 0);
MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(),
copy_size);
auto copy_proc = [&](absl::Span<const int64> indexes) {
// Map from multi-dimensional index, to source index.
std::transform(indexes.begin(), indexes.end(), src_base.begin(),
src_indexes.begin(), std::plus<int64>());
// Map from multi-dimensional index, to destination index.
std::transform(indexes.begin(), indexes.end(), dest_base.begin(),
dest_indexes.begin(), std::plus<int64>());
int64 src_index = linear_index(src_literal.shape(), src_indexes);
int64 dest_index = linear_index(shape(), dest_indexes);
// `this->` is needed to workaround MSVC bug: #16882
StridedCopy(this->data<NativeT>(), dest_index, stride_config.dest_stride,
src_literal.data<NativeT>(), src_index,
stride_config.source_stride, stride_config.minor_loop_size);
return true;
};
ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base,
stride_config.dimensions, stride_config.step,
copy_proc);
}
return Status::OK();
}
Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
absl::Span<const int64> src_index,
absl::Span<const int64> dest_index) {
DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(
src_literal.shape(), src_index);
const int64 dest_linear_index =
IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index);
const int64 primitive_size =
ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
char* dest_address =
static_cast<char*>(untyped_data()) + dest_linear_index * primitive_size;
const char* source_address =
static_cast<const char*>(src_literal.untyped_data()) +
src_linear_index * primitive_size;
if (dest_address != source_address) {
memcpy(dest_address, source_address, primitive_size);
}
return Status::OK();
}
/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
const LiteralProto& proto, bool prohibit_empty_literal) {
if (!proto.has_shape()) {
return InvalidArgument("LiteralProto has no shape");
}
Shape shape(proto.shape());
if (ShapeUtil::HasPrimitiveType(shape, OPAQUE_TYPE)) {
return InvalidArgument(
"Literal shape cannot include OPAQUE_TYPE sub-shape");
}
if (!LayoutUtil::HasLayout(shape)) {
return InvalidArgument("LiteralProto has no layout");
}
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
Literal literal(shape);
TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
[&](const ShapeIndex& index, Piece* piece) {
const LiteralProto* proto_element = &proto;
for (int64 i : index) {
CHECK_LT(i, proto_element->tuple_literals_size());
proto_element = &proto_element->tuple_literals(i);
}
if (piece->subshape().IsTuple()) {
if (proto_element->tuple_literals_size() !=
ShapeUtil::TupleElementCount(piece->subshape())) {
return InvalidArgument(
"Expected %d tuple elements in LiteralProto, has %d",
ShapeUtil::TupleElementCount(piece->subshape()),
proto_element->tuple_literals_size());
}
return Status::OK();
}
if (piece->subshape().element_type() == TOKEN) {
return Status::OK();
}
CHECK(piece->subshape().IsArray());
// When prohibit_empty_literal is false (allowing literals with no
// values), only copy from the proto if it actually has values. This mode
// is used for a learned cost model.
if (prohibit_empty_literal || LiteralProtoHasValues(*proto_element)) {
TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
}
return Status::OK();
}));
return std::move(literal);
}
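// Round-trip sketch (illustrative; pairs this function with
// LiteralBase::ToProto(), declared in literal.h):
//
//   Literal original(ShapeUtil::MakeShape(F32, {2}));
//   original.Set<float>({0}, 1.0f);
//   original.Set<float>({1}, 2.0f);
//   LiteralProto proto = original.ToProto();
//   StatusOr<Literal> copy = MutableLiteralBase::CreateFromProto(proto);
//   CHECK(copy.ok() && copy.ValueOrDie() == original);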
std::vector<Literal> Literal::DecomposeTuple() {
CHECK(shape().IsTuple());
std::vector<Literal> elements;
for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
/*allocate_arrays=*/false));
Literal& element = elements.back();
element.root_piece_->ForEachMutableSubpiece(
[&](const ShapeIndex& index, Piece* dest_piece) {
ShapeIndex src_index = {i};
for (int64 j : index) {
src_index.push_back(j);
}
Piece& src_piece = piece(src_index);
// Move the respective buffer over to the element Literal.
dest_piece->set_buffer(src_piece.buffer());
dest_piece->set_dynamic_size_buffer(src_piece.dynamic_size_buffer());
src_piece.set_buffer(nullptr);
src_piece.set_dynamic_size_buffer(nullptr);
});
}
// Set this literal to be nil-shaped.
*this = Literal();
return elements;
}
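// Sketch (illustrative): DecomposeTuple is the inverse of MoveIntoTuple,
// defined later in this file; both move buffers rather than copying them.
//
//   Literal pair(ShapeUtil::MakeTupleShape(
//       {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeShape(S32, {})}));
//   std::vector<Literal> parts = pair.DecomposeTuple();
//   // `pair` is now nil-shaped; parts[0] and parts[1] own the buffers.
//   Literal rebuilt = MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(parts));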
namespace {
// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
// the array slices are indicated by dest_shape and src_shape respectively.
template <typename NativeT>
void CopyElementsBetween(absl::Span<NativeT> dest,
absl::Span<const NativeT> src, const Shape& dest_shape,
const Shape& src_shape) {
CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
if (ShapeUtil::IsZeroElementArray(dest_shape)) {
return;
}
std::vector<int64> index(dest_shape.rank());
do {
dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
} while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index)));
}
} // namespace
int32 LiteralBase::Piece::GetDynamicSize(int64 dim_index) const {
CHECK(LayoutUtil::IsDenseArray(subshape()));
if (!subshape_->is_dynamic_dimension(dim_index)) {
// This is a static dimension, return size.
return subshape_->dimensions(dim_index);
}
CHECK_NE(dynamic_size_buffer(), nullptr);
return dynamic_size_buffer_[dim_index];
}
void LiteralBase::Piece::SetDynamicSize(int64 dim_index, int32 size) {
CHECK(LayoutUtil::IsDenseArray(subshape()));
CHECK(subshape_->is_dynamic_dimension(dim_index));
if (dynamic_size_buffer() == nullptr) {
// Lazily initialize the dynamic size buffer.
set_dynamic_size_buffer(static_cast<int32*>(tensorflow::port::AlignedMalloc(
dynamic_size_buffer_bytes(), kMinimumAlignment)));
}
dynamic_size_buffer_[dim_index] = size;
}
Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src,
bool only_dynamic_bound) {
CHECK(subshape_ != nullptr);
CHECK(src.subshape_ != nullptr);
if (ShapeUtil::Equal(subshape(), src.subshape())) {
// If the layouts are equal it's faster just to memcpy.
memcpy(buffer(), src.buffer(), src.size_bytes());
} else {
std::vector<int64> origin(subshape().rank(), 0);
switch (subshape().element_type()) {
#define COPY_ELEMENTS(XLA_T, NATIVE_T) \
case (XLA_T): \
if (only_dynamic_bound) { \
CopyElementsWithDynamicBound<NATIVE_T>(src); \
} else { \
CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
subshape(), src.subshape()); \
} \
break;
COPY_ELEMENTS(U8, uint8);
COPY_ELEMENTS(U16, uint16);
COPY_ELEMENTS(U32, uint32);
COPY_ELEMENTS(U64, uint64);
COPY_ELEMENTS(S8, int8);
COPY_ELEMENTS(S16, int16);
COPY_ELEMENTS(S32, int32);
COPY_ELEMENTS(S64, int64);
COPY_ELEMENTS(F16, half);
COPY_ELEMENTS(BF16, bfloat16);
COPY_ELEMENTS(F32, float);
COPY_ELEMENTS(F64, double);
COPY_ELEMENTS(C64, complex64);
COPY_ELEMENTS(C128, complex128);
COPY_ELEMENTS(PRED, bool);
#undef COPY_ELEMENTS
default:
return Unimplemented(
"Copying a Literal object with element type %s is not implemented.",
PrimitiveType_Name(subshape().element_type()));
}
}
DCHECK_EQ(dynamic_size_buffer_bytes(), src.dynamic_size_buffer_bytes());
if (subshape().is_dynamic() && src.subshape().is_dynamic()) {
CHECK_NE(dynamic_size_buffer_, nullptr);
CHECK_NE(src.dynamic_size_buffer_, nullptr);
memcpy(dynamic_size_buffer(), src.dynamic_size_buffer(),
src.dynamic_size_buffer_bytes());
}
return Status::OK();
}
void MutableLiteralBase::SetDynamicSize(int64 dim_index, int32 size) {
return SetDynamicSize(dim_index, {}, size);
}
void MutableLiteralBase::SetDynamicSize(int64 dim_index,
const ShapeIndex& shape_index,
int32 size) {
Shape* subshape_ = ShapeUtil::GetMutableSubshape(shape_.get(), shape_index);
CHECK_GE(subshape_->dimensions(dim_index), size);
if (subshape_->dimensions(dim_index) == size) {
subshape_->set_dynamic_dimension(dim_index, false);
return;
}
subshape_->set_dynamic_dimension(dim_index, true);
piece(shape_index).SetDynamicSize(dim_index, size);
}
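// Dynamic-size sketch (illustrative; assumes the ShapeUtil::MakeShape
// overload that takes a dynamic_dimensions vector):
//
//   Literal l(ShapeUtil::MakeShape(S32, {4}, {true}));  // s32[<=4]
//   l.SetDynamicSize(0, 2);            // logical size 2 within bound 4
//   CHECK_EQ(l.GetDynamicSize(0), 2);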
Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
const ShapeIndex& dest_shape_index,
const ShapeIndex& src_shape_index,
bool only_dynamic_bound) {
const Shape& dest_subshape =
ShapeUtil::GetSubshape(shape(), dest_shape_index);
const Shape& src_subshape =
ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
if (only_dynamic_bound) {
auto bound_shape = dest_subshape.is_static() ? src_subshape : dest_subshape;
auto compact_shape =
dest_subshape.is_static() ? dest_subshape : src_subshape;
CHECK(ShapeUtil::DynamicShapeIsCompatible(compact_shape, bound_shape))
<< compact_shape.ToString() << " vs " << bound_shape.ToString();
} else {
if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
return InvalidArgument(
"Destination subshape incompatible with source subshape: %s vs %s",
ShapeUtil::HumanString(dest_subshape),
ShapeUtil::HumanString(src_subshape));
}
}
return root_piece_->ForEachMutableSubpieceWithStatus(
[&](const ShapeIndex& index, Piece* piece) {
if (!piece->subshape().IsArray()) {
return Status::OK();
}
// Determine if this index is in the part of this literal that we want
// to copy over from src_literal.
bool in_subtree_to_copy = true;
for (int i = 0; i < dest_shape_index.size(); ++i) {
if (index[i] != dest_shape_index[i]) {
in_subtree_to_copy = false;
break;
}
}
if (!in_subtree_to_copy) {
return Status::OK();
}
// Construct the index of the corresponding piece in the source literal.
ShapeIndex src_piece_index = src_shape_index;
for (int64 i = dest_shape_index.size(), end = index.size(); i < end;
++i) {
src_piece_index.push_back(index[i]);
}
TF_RETURN_IF_ERROR(
piece->CopyFrom(src_literal.piece(src_piece_index),
/*only_dynamic_bound=*/only_dynamic_bound));
return Status::OK();
});
}
Status Literal::MoveFrom(Literal&& src_literal,
const ShapeIndex& dest_shape_index) {
const Shape& dest_subshape =
ShapeUtil::GetSubshape(shape(), dest_shape_index);
if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) {
return InvalidArgument(
"Destination subshape not equal to source shape: %s vs %s",
ShapeUtil::HumanString(dest_subshape),
ShapeUtil::HumanString(src_literal.shape()));
}
src_literal.root_piece_->ForEachSubpiece(
[&](const ShapeIndex& src_index, const Piece& src_piece) {
if (!src_piece.subshape().IsArray()) {
return;
}
ShapeIndex dest_index = dest_shape_index;
for (int64 i : src_index) {
dest_index.push_back(i);
}
Piece& dest_piece = piece(dest_index);
tensorflow::port::AlignedFree(dest_piece.buffer());
tensorflow::port::AlignedFree(dest_piece.dynamic_size_buffer());
dest_piece.set_buffer(src_piece.buffer());
dest_piece.set_dynamic_size_buffer(src_piece.dynamic_size_buffer());
});
src_literal.shape_ = absl::make_unique<Shape>(ShapeUtil::MakeNil());
delete src_literal.root_piece_;
src_literal.root_piece_ = new LiteralBase::Piece();
src_literal.root_piece_->set_subshape(src_literal.shape_.get());
return Status::OK();
}
Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal,
absl::Span<const int64> src_base,
absl::Span<const int64> dest_base,
absl::Span<const int64> copy_size) {
TF_RET_CHECK(shape().IsArray()) << ShapeUtil::HumanString(shape());
TF_RET_CHECK(src_literal.shape().IsArray())
<< ShapeUtil::HumanString(src_literal.shape());
TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape()));
switch (shape().element_type()) {
case U8:
return CopySliceFromInternal<uint8>(src_literal, src_base, dest_base,
copy_size);
case U16:
return CopySliceFromInternal<uint16>(src_literal, src_base, dest_base,
copy_size);
case U32:
return CopySliceFromInternal<uint32>(src_literal, src_base, dest_base,
copy_size);
case U64:
return CopySliceFromInternal<uint64>(src_literal, src_base, dest_base,
copy_size);
case S8:
return CopySliceFromInternal<int8>(src_literal, src_base, dest_base,
copy_size);
case S16:
return CopySliceFromInternal<int16>(src_literal, src_base, dest_base,
copy_size);
case S32:
return CopySliceFromInternal<int32>(src_literal, src_base, dest_base,
copy_size);
case S64:
return CopySliceFromInternal<int64>(src_literal, src_base, dest_base,
copy_size);
case F16:
return CopySliceFromInternal<half>(src_literal, src_base, dest_base,
copy_size);
case BF16:
return CopySliceFromInternal<bfloat16>(src_literal, src_base, dest_base,
copy_size);
case F32:
return CopySliceFromInternal<float>(src_literal, src_base, dest_base,
copy_size);
case F64:
return CopySliceFromInternal<double>(src_literal, src_base, dest_base,
copy_size);
case C64:
return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base,
copy_size);
case C128:
return CopySliceFromInternal<complex128>(src_literal, src_base, dest_base,
copy_size);
case PRED:
return CopySliceFromInternal<bool>(src_literal, src_base, dest_base,
copy_size);
default:
break;
}
return Unimplemented(
"Copying a slice from a Literal object with element type %d is not "
"implemented.",
shape().element_type());
}
void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
CHECK(shape().IsArray());
CHECK_EQ(shape().rank(), 1);
CHECK_EQ(element_count(), values.bits());
CHECK_EQ(shape().element_type(), PRED);
for (int64 i = 0; i < static_cast<int64>(values.bits()); ++i) {
Set({i}, values.get(i));
}
}
Literal LiteralBase::Relayout(const Layout& new_layout,
const ShapeIndex& shape_index) const {
// Create new shape with 'new_layout' set at the given shape index.
Shape new_shape = shape();
Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
*subshape->mutable_layout() = new_layout;
Literal result(new_shape);
TF_CHECK_OK(result.CopyFrom(*this));
return result;
}
Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
<< "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
<< " not compatible with literal shape "
<< ShapeUtil::HumanString(shape());
Literal result = CreateFromShape(shape_with_layout);
ShapeUtil::ForEachSubshape(
result.shape(),
[this, &result](const Shape& subshape, const ShapeIndex& index) {
if (subshape.IsArray()) {
TF_CHECK_OK(result.CopyFrom(*this,
/*dest_shape_index=*/index,
/*src_shape_index=*/index));
}
});
return result;
}
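// Relayout sketch (illustrative): logical contents are unchanged; only the
// physical order of the flat buffer differs.
//
//   Literal row_major(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
//   row_major.Set<float>({0, 1}, 3.0f);
//   Literal col_major = row_major.Relayout(LayoutUtil::MakeLayout({0, 1}));
//   CHECK_EQ(col_major.Get<float>({0, 1}), 3.0f);  // same logical element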
Literal LiteralBase::ToBoundedDynamic(const Shape& bounded_shape) const {
CHECK(bounded_shape.is_dynamic());
Literal result(bounded_shape);
ShapeUtil::ForEachSubshape(
shape(), [&](const Shape& subshape, const ShapeIndex& index) {
if (!subshape.IsArray()) {
return;
}
for (int64 i = 0; i < subshape.rank(); ++i) {
result.SetDynamicSize(i, subshape.dimensions(i));
}
});
TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true));
return result;
}
Literal LiteralBase::ToStatic() const {
// Create a new shape with every dynamic dimension frozen to its current
// dynamic size.
Shape new_shape = shape();
ShapeUtil::ForEachMutableSubshape(
&new_shape, [this](Shape* subshape, const ShapeIndex& index) {
if (!subshape->IsArray()) {
return;
}
for (int64 i = 0; i < subshape->rank(); ++i) {
subshape->set_dynamic_dimension(i, false);
subshape->set_dimensions(i, GetDynamicSize(i, index));
}
});
Literal result(new_shape);
TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true));
return result;
}
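// Sketch of the dynamic <-> static round trip (shapes are illustrative and
// assume the MakeShape overload taking dynamic_dimensions):
//
//   Literal bounded(ShapeUtil::MakeShape(S32, {4}, {true}));  // s32[<=4]
//   bounded.SetDynamicSize(0, 2);
//   Literal fixed = bounded.ToStatic();  // shape is s32[2]
//   Literal back = fixed.ToBoundedDynamic(
//       ShapeUtil::MakeShape(S32, {4}, {true}));  // s32[<=4] with size 2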
StatusOr<Literal> LiteralBase::Broadcast(
const Shape& result_shape, absl::Span<const int64> dimensions) const {
if (!shape().IsArray()) {
return InvalidArgument("Broadcast only supports arrays.");
}
for (int64 i = 0, end = dimensions.size(); i < end; i++) {
TF_RET_CHECK(shape().dimensions(i) ==
result_shape.dimensions(dimensions[i]));
}
Literal result(result_shape);
// scratch_source_index is temporary storage for the computed index into the
// input literal. It is declared outside the loop to avoid allocating a
// std::vector on every iteration of ShapeUtil::ForEachIndex.
std::vector<int64> scratch_source_index(shape().dimensions_size());
char* dest_data = static_cast<char*>(result.untyped_data());
const char* source_data = static_cast<const char*>(untyped_data());
const int64 primitive_size =
ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
for (int64 i = 0, end = dimensions.size(); i < end; ++i) {
int64 dynamic_size = GetDynamicSize(i);
result.SetDynamicSize(dimensions[i], dynamic_size);
}
ShapeUtil::ForEachIndex(
result_shape, [&](absl::Span<const int64> output_index) {
for (int64 i = 0, end = dimensions.size(); i < end; ++i) {
scratch_source_index[i] = output_index[dimensions[i]];
}
int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex(
result_shape, output_index);
int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex(
shape(), scratch_source_index);
memcpy(dest_data + primitive_size * dest_index,
source_data + primitive_size * source_index, primitive_size);
return true;
});
return std::move(result);
}
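// Broadcast sketch (illustrative): `dimensions` maps each dimension of this
// literal to a dimension of result_shape.
//
//   Literal v(ShapeUtil::MakeShape(S32, {2}));
//   v.Set<int32>({0}, 7);
//   v.Set<int32>({1}, 8);
//   // Duplicate v along a new minor dimension: s32[2] -> s32[2,3].
//   StatusOr<Literal> b = v.Broadcast(ShapeUtil::MakeShape(S32, {2, 3}), {0});
//   CHECK_EQ(b.ValueOrDie().Get<int32>({1, 2}), 8);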
StatusOr<Literal> LiteralBase::Reshape(
absl::Span<const int64> dimensions) const {
if (!shape().IsArray()) {
return InvalidArgument("Reshape does not support tuples.");
}
if (shape().is_dynamic()) {
return Unimplemented("Dynamic reshape is not implemented.");
}
Literal output;
if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank()));
} else {
output = Clone();
}
// Because the layout is monotonic, we can simply reuse the same sequence of
// values without changing their order.
*output.mutable_shape_do_not_use() =
ShapeUtil::MakeShape(shape().element_type(), dimensions);
int64 elements_before = ShapeUtil::ElementsIn(shape());
int64 elements_after = ShapeUtil::ElementsIn(output.shape());
if (elements_before != elements_after) {
return InvalidArgument(
"Shapes before and after Literal::Reshape have different numbers "
"of elements: %s vs %s.",
ShapeUtil::HumanString(shape()),
ShapeUtil::HumanString(output.shape()));
}
return std::move(output);
}
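// Reshape sketch (illustrative): with the default (row-major) layout the
// flat element order is preserved, so {{1, 2, 3}, {4, 5, 6}} reshaped to
// {3, 2} reads {{1, 2}, {3, 4}, {5, 6}}.
//
//   Literal m(ShapeUtil::MakeShape(S32, {2, 3}));
//   // ... populate m with 1..6 in row-major order ...
//   StatusOr<Literal> r = m.Reshape({3, 2});
//   CHECK_EQ(r.ValueOrDie().Get<int32>({1, 0}), 3);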
Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
CHECK(shape().IsArray()) << "Tuple is not supported for transpose";
CHECK(IsPermutation(permutation, shape().rank()))
<< "Given permutation is not a permutation of dimension numbers";
// To transpose the array, we just permute the dimensions and layout, and
// do a straight memory copy of the raw data set.
// This is considerably faster than iterating over every array element using
// the EachCell<>() and Set<>() APIs.
std::vector<int64> inverse_permutation = InversePermutation(permutation);
Shape permuted_shape =
ShapeUtil::PermuteDimensions(inverse_permutation, shape());
// Replace the layout with one affine to this shape, such that a
// transpose operation can be performed by leaving the flat values
// representation intact.
// For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
// The shape with affine layout resulting from that operation will be
// F32[8,11]{0,1}, since it leaves the original most-minor dimension (the one
// of size 8) as the most-minor dimension.
//
// Essentially, given MinMaj(Di) as the position of dimension Di within the
// minor-to-major vector, and T(Di) as the index that the original dimension
// Di has within the transposed array, a layout is affine if
// MinMaj(Di) == TMinMaj(T(Di)), where TMinMaj() is the minor-to-major vector
// of the affine layout.
CHECK(LayoutUtil::IsDenseArray(permuted_shape));
Layout* layout = permuted_shape.mutable_layout();
layout->clear_minor_to_major();
for (auto index : LayoutUtil::MinorToMajor(shape())) {
layout->add_minor_to_major(inverse_permutation[index]);
}
Literal new_literal(permuted_shape);
for (int64 i = 0; i < shape().rank(); i++) {
new_literal.SetDynamicSize(inverse_permutation[i], GetDynamicSize(i));
}
DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
ShapeUtil::ByteSizeOf(shape()));
std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
return new_literal;
}
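// Transpose sketch (illustrative): the flat buffer is memcpy'd unchanged;
// only the shape and layout are permuted, so element {i, j} of the input
// becomes element {j, i} of the output under a {1, 0} permutation.
//
//   Literal m(ShapeUtil::MakeShape(F32, {2, 3}));
//   m.Set<float>({0, 2}, 9.0f);
//   Literal t = m.Transpose({1, 0});  // shape f32[3,2]
//   CHECK_EQ(t.Get<float>({2, 0}), 9.0f);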
template <typename NativeT>
Literal LiteralBase::SliceInternal(
const Shape& result_shape, absl::Span<const int64> start_indices) const {
Literal result_literal(result_shape);
DimensionVector new_indices(result_shape.rank());
CHECK(result_literal
.Populate<NativeT>([&](absl::Span<const int64> indices) {
for (int64 i = 0; i < result_shape.rank(); ++i) {
new_indices[i] = indices[i] + start_indices[i];
}
return Get<NativeT>(new_indices);
})
.ok());
for (int64 dnum = 0; dnum < shape().rank(); ++dnum) {
if (shape().is_dynamic_dimension(dnum)) {
int64 dynamic_size = GetDynamicSize(dnum) - start_indices[dnum];
CHECK_GE(dynamic_size, 0) << GetDynamicSize(dnum);
dynamic_size = std::min(dynamic_size, result_shape.dimensions(dnum));
result_literal.SetDynamicSize(dnum, dynamic_size);
}
}
return result_literal;
}
Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices) const {
CHECK(shape().IsArray()) << "tuple is not supported for slice";
DimensionVector result_dimensions;
for (int64 dnum = 0; dnum < shape().rank(); ++dnum) {
CHECK_GE(start_indices[dnum], 0);
CHECK_LE(limit_indices[dnum], shape().dimensions(dnum))
<< "dnum = " << dnum;
int64 dimension = limit_indices[dnum] - start_indices[dnum];
CHECK_GE(dimension, 0) << "dnum = " << dnum;
result_dimensions.push_back(dimension);
}
auto result_shape =
ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
LayoutUtil::MinorToMajor(shape()));
ShapeUtil::CopyDynamicDimensions(&result_shape, shape());
switch (result_shape.element_type()) {
case PRED:
return SliceInternal<bool>(result_shape, start_indices);
case U8:
return SliceInternal<uint8>(result_shape, start_indices);
case U16:
return SliceInternal<uint16>(result_shape, start_indices);
case U32:
return SliceInternal<uint32>(result_shape, start_indices);
case U64:
return SliceInternal<uint64>(result_shape, start_indices);
case S8:
return SliceInternal<int8>(result_shape, start_indices);
case S16:
return SliceInternal<int16>(result_shape, start_indices);
case S32:
return SliceInternal<int32>(result_shape, start_indices);
case S64:
return SliceInternal<int64>(result_shape, start_indices);
case F16:
return SliceInternal<half>(result_shape, start_indices);
case BF16:
return SliceInternal<bfloat16>(result_shape, start_indices);
case F32:
return SliceInternal<float>(result_shape, start_indices);
case F64:
return SliceInternal<double>(result_shape, start_indices);
case C64:
return SliceInternal<complex64>(result_shape, start_indices);
case C128:
return SliceInternal<complex128>(result_shape, start_indices);
default:
LOG(FATAL) << "not yet implemented: "
<< PrimitiveType_Name(result_shape.element_type());
}
}
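// Slice sketch (illustrative): start indices are inclusive, limit indices
// exclusive, so the result shape is limit - start in every dimension.
//
//   Literal m(ShapeUtil::MakeShape(S32, {4, 4}));
//   // ... populate m ...
//   Literal s = m.Slice({1, 1}, {3, 4});  // shape s32[2,3]
//   // s.Get<int32>({0, 0}) == m.Get<int32>({1, 1})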
Literal LiteralBase::Clone() const {
Literal result(shape());
TF_CHECK_OK(result.CopyFrom(*this));
return result;
}
string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
const ShapeIndex& shape_index) const {
const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
CHECK(LayoutUtil::IsDenseArray(subshape));
switch (subshape.element_type()) {
case PRED:
return Get<bool>(multi_index, shape_index) ? "true" : "false";
case S8:
return StrCat(Get<int8>(multi_index, shape_index));
case S16:
return StrCat(Get<int16>(multi_index, shape_index));
case S32:
return StrCat(Get<int32>(multi_index, shape_index));
case S64:
return StrCat(Get<int64>(multi_index, shape_index));
case U8:
return StrCat(Get<uint8>(multi_index, shape_index));
case U16:
return StrCat(Get<uint16>(multi_index, shape_index));
case U32:
return StrCat(Get<uint32>(multi_index, shape_index));
case U64:
return StrCat(Get<uint64>(multi_index, shape_index));
case F16:
return RoundTripFpToString(Get<half>(multi_index, shape_index));
case F32:
return RoundTripFpToString(Get<float>(multi_index, shape_index));
case BF16:
return RoundTripFpToString(Get<bfloat16>(multi_index, shape_index));
case F64:
return RoundTripFpToString(Get<double>(multi_index, shape_index));
case C64: {
complex64 c = Get<complex64>(multi_index, shape_index);
return StrCat("(", RoundTripFpToString(c.real()), ", ",
RoundTripFpToString(c.imag()), ")");
}
case C128: {
complex128 c = Get<complex128>(multi_index, shape_index);
return StrCat("(", RoundTripFpToString(c.real()), ", ",
RoundTripFpToString(c.imag()), ")");
}
default:
LOG(FATAL) << PrimitiveType_Name(subshape.element_type());
}
}
absl::optional<int64> LiteralBase::GetIntegralAsS64(
absl::Span<const int64> multi_index) const {
CHECK(LayoutUtil::IsDenseArray(shape()));
switch (shape().element_type()) {
case PRED:
return Get<bool>(multi_index);
case S8:
return Get<int8>(multi_index);
case U8:
return Get<uint8>(multi_index);
case S16:
return Get<int16>(multi_index);
case U16:
return Get<uint16>(multi_index);
case S32:
return Get<int32>(multi_index);
case U32:
return Get<uint32>(multi_index);
case S64:
return Get<int64>(multi_index);
case U64:
return Get<uint64>(multi_index);
default:
return absl::nullopt;
}
}
absl::optional<double> LiteralBase::GetAsDouble(
absl::Span<const int64> multi_index) const {
CHECK(LayoutUtil::IsDenseArray(shape()));
switch (shape().element_type()) {
case F16:
return static_cast<double>(Get<half>(multi_index));
case F32:
return static_cast<double>(Get<float>(multi_index));
case F64:
return Get<double>(multi_index);
case BF16:
return static_cast<double>(Get<bfloat16>(multi_index));
default:
return absl::nullopt;
}
}
absl::optional<complex128> LiteralBase::GetAsComplex128(
absl::Span<const int64> multi_index) const {
switch (shape().element_type()) {
case BF16:
return {{static_cast<double>(Get<bfloat16>(multi_index)), 0}};
case F16:
return {{static_cast<double>(Get<Eigen::half>(multi_index)), 0}};
case F32:
return {{Get<float>(multi_index), 0}};
case F64:
return {{Get<double>(multi_index), 0}};
case C64:
return {Get<complex64>(multi_index)};
case C128:
return {Get<complex128>(multi_index)};
case S8:
return {Get<int8>(multi_index)};
default:
return absl::nullopt;
}
}
size_t LiteralBase::Hash() const {
using tensorflow::Hash64;
using tensorflow::Hash64Combine;
size_t hash_value = ShapeUtil::Hash(shape());
ShapeUtil::ForEachSubshape(
shape(), [&](const Shape& subshape, const ShapeIndex& index) {
if (!subshape.IsArray()) {
return;
}
CHECK(LayoutUtil::IsDense(subshape.layout()));
hash_value = Hash64Combine(
hash_value, Hash64(static_cast<const char*>(untyped_data(index)),
size_bytes(index)));
});
return hash_value;
}
Status MutableLiteralBase::SetIntegralAsS64(absl::Span<const int64> multi_index,
int64 value) {
CHECK(LayoutUtil::IsDenseArray(shape()));
switch (shape().element_type()) {
case PRED:
Set<bool>(multi_index, value);
break;
case U8:
Set<uint8>(multi_index, value);
break;
case S32:
Set<int32>(multi_index, value);
break;
case S64:
Set<int64>(multi_index, value);
break;
case U32:
Set<uint32>(multi_index, value);
break;
case U64:
Set<uint64>(multi_index, value);
break;
default:
return FailedPrecondition("Array element type is not integral: %s",
PrimitiveType_Name(shape().element_type()));
}
return Status::OK();
}
Status MutableLiteralBase::SetFromDouble(absl::Span<const int64> multi_index,
double value) {
CHECK(LayoutUtil::IsDenseArray(shape()));
switch (shape().element_type()) {
case F16:
Set<half>(multi_index, Eigen::half(value));
break;
case F32:
Set<float>(multi_index, value);
break;
case F64:
Set<double>(multi_index, value);
break;
case BF16:
Set<bfloat16>(multi_index, static_cast<bfloat16>(value));
break;
default:
return FailedPrecondition("Array element type is not floating: %s",
PrimitiveType_Name(shape().element_type()));
}
return Status::OK();
}
namespace {
string ShapeToString(bool print_layout, const Shape& shape) {
return print_layout ? ShapeUtil::HumanStringWithLayout(shape)
: ShapeUtil::HumanString(shape);
}
void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
bool print_shape, bool print_layout,
std::vector<string>* pieces);
void TupleToStringHelper(const LiteralBase& literal,
const ShapeIndex& shape_index, bool print_shape,
bool print_layout, std::vector<string>* pieces) {
const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
pieces->push_back("(\n");
std::vector<string> tuple_pieces;
for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) {
ShapeIndex element_index = shape_index;
element_index.push_back(i);
std::vector<string> element_pieces;
ToStringHelper(literal, element_index, print_shape, print_layout,
&element_pieces);
tuple_pieces.push_back(absl::StrJoin(element_pieces, ""));
}
pieces->push_back(absl::StrJoin(tuple_pieces, ",\n"));
pieces->push_back("\n)");
}
void DenseArrayToStringHelper(const LiteralBase& literal,
const ShapeIndex& shape_index, bool print_shape,
bool print_layout, std::vector<string>* pieces) {
const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
int64 rank = subshape.rank();
std::function<void(absl::Span<const int64> dimensions, std::vector<int64>*)>
to_string_recursive = [&](absl::Span<const int64> dimensions,
std::vector<int64>* accum_indices) {
// dimensions.size() decreases by 1 at each recursive call,
// and accum_indices->size() increases by 1.
// Their sum is equal to the rank of the tensor.
CHECK_EQ(rank, dimensions.size() + accum_indices->size());
auto brace_to_string = [&](string brace) -> string {
// Handle 1D tensor
if (rank == 1) {
return brace;
}
// Handle the innermost tensor of a 2D+ tensor.
if (dimensions.size() == 1 && brace == "{") {
return StrCat(" ", brace, dimensions[0] <= 1 ? "" : " ");
}
if (dimensions.size() == 1 && brace == "}") {
return StrCat(dimensions[0] <= 1 ? "" : " ", brace);
}
// Handle the non-innermost tensors of a 2D+ tensor.
if (brace == "{") {
const int64 accum_indices_size = accum_indices->size();
if (rank > 3 && !accum_indices->empty() &&
accum_indices_size < rank) {
int index = accum_indices->size() - 1;
int value = accum_indices->back();
return StrCat(brace, " /*i", index, "=", value, "*/\n");
}
return StrCat(brace, "\n");
}
return StrCat("\n", brace);
};
if (dimensions.empty()) {
// Display predicates as 0s and 1s to keep the string compact.
string elem;
if (subshape.element_type() == PRED && rank > 0) {
elem = literal.Get<bool>(*accum_indices, shape_index) ? "1" : "0";
} else {
elem = literal.GetAsString(*accum_indices, shape_index);
}
pieces->push_back(elem);
} else {
pieces->push_back(brace_to_string("{"));
for (int i = 0; i < dimensions[0]; ++i) {
std::vector<int64> cloned_indices(*accum_indices);
cloned_indices.push_back(i);
to_string_recursive(dimensions.subspan(1), &cloned_indices);
if (i < dimensions[0] - 1) {
pieces->push_back(",");
pieces->push_back(dimensions.size() > 1 ? "\n" : " ");
}
}
pieces->push_back(brace_to_string("}"));
}
};
if (print_shape) {
pieces->push_back(ShapeToString(print_layout, subshape));
if (subshape.is_dynamic()) {
pieces->push_back("(");
for (int64 i = 0; i < subshape.dimensions_size(); ++i) {
pieces->push_back(StrCat(literal.GetDynamicSize(i, shape_index)));
if (i < subshape.dimensions_size() - 1) {
pieces->push_back(",");
}
}
pieces->push_back(")");
}
pieces->push_back(" ");
}
std::vector<int64> indices = {};
std::vector<int64> dimensions;
dimensions.reserve(subshape.rank());
for (int64 i = 0; i < subshape.rank(); ++i) {
dimensions.push_back(literal.GetDynamicSize(i, shape_index));
}
to_string_recursive(dimensions, &indices);
}
void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
bool print_shape, bool print_layout,
std::vector<string>* pieces) {
const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
CHECK(LayoutUtil::HasLayout(literal.shape()));
CHECK(LayoutUtil::HasLayout(subshape));
if (subshape.IsTuple()) {
TupleToStringHelper(literal, shape_index, print_shape, print_layout,
pieces);
} else if (subshape.IsToken()) {
pieces->push_back("token");
} else {
CHECK(LayoutUtil::IsDenseArray(subshape));
DenseArrayToStringHelper(literal, shape_index, print_shape, print_layout,
pieces);
}
}
} // namespace
string LiteralBase::ToString() const {
std::vector<string> pieces;
CHECK(LayoutUtil::HasLayout(this->shape()));
ToStringHelper(*this, {}, /*print_shape=*/true,
/*print_layout=*/false, &pieces);
return absl::StrJoin(pieces, "");
}
string LiteralBase::ToStringWithoutShape() const {
std::vector<string> pieces;
CHECK(LayoutUtil::HasLayout(this->shape()));
ToStringHelper(*this, {}, /*print_shape=*/false,
/*print_layout=*/false, &pieces);
return absl::StrJoin(pieces, "");
}
string LiteralBase::ToStringWithLayout() const {
std::vector<string> pieces;
CHECK(LayoutUtil::HasLayout(this->shape()));
ToStringHelper(*this, {}, /*print_shape=*/true,
/*print_layout=*/true, &pieces);
return absl::StrJoin(pieces, "");
}
void LiteralBase::EachCellAsString(
const std::function<void(absl::Span<const int64> indices,
const string& value)>& per_cell) const {
if (ShapeUtil::IsZeroElementArray(shape())) {
return;
}
std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
shape(), /*linear_index=*/0);
do {
per_cell(indices, GetAsString(indices));
} while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
}
namespace {
template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
const ConverterType& converter) {
CHECK(src_literal.shape().IsArray());
Literal result_literal(ShapeUtil::ChangeElementType(
src_literal.shape(),
primitive_util::NativeToPrimitiveType<NativeDestT>()));
auto src_data = src_literal.data<NativeSrcT>();
auto dest_data = result_literal.template data<NativeDestT>();
int64 num_elements = src_literal.element_count();
for (int64 i = 0; i < num_elements; ++i) {
dest_data[i] = converter(src_data[i]);
}
return result_literal;
}
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(std::is_same<NativeSrcT, Eigen::half>::value) &&
(std::is_same<NativeDestT, complex64>::value ||
std::is_same<NativeDestT, complex128>::value),
Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) {
return NativeDestT(static_cast<typename NativeDestT::value_type>(src));
};
return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
src_literal, converter);
}
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(!std::is_same<NativeSrcT, Eigen::half>::value) ||
(!std::is_same<NativeDestT, complex64>::value &&
!std::is_same<NativeDestT, complex128>::value),
Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
src_literal, converter);
}
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT) &&
!std::is_same<NativeDestT, Eigen::half>::value),
Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) {
return absl::bit_cast<NativeDestT>(GetRawValue(src));
};
return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
src_literal, converter);
}
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(Eigen::half) &&
std::is_same<NativeDestT, Eigen::half>::value),
Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
// Eigen::half doesn't satisfy the absl::bit_cast contract, so explicitly
// cast to unsigned short and then use raw_uint16_to_half.
auto converter = [](NativeSrcT src) {
return Eigen::half_impl::raw_uint16_to_half(
absl::bit_cast<uint16>(GetRawValue(src)));
};
return ConvertBetweenNativeTypesWithConverter<NativeSrcT, Eigen::half>(
src_literal, converter);
}
// This overload is here to make the compiler happy: bit_cast has a static
// check that the types are the same size. It should never be called, because
// the source and destination types are checked for identical sizes higher up.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
LOG(FATAL) << "Invalid bitcast between types of different sizes.";
}
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
if (bitcast) {
return BitcastBetweenNativeTypes<
typename primitive_util::PrimitiveTypeToNative<
primitive_src_type>::type,
typename primitive_util::PrimitiveTypeToNative<
primitive_dest_type>::type>(src_literal);
} else {
return ConvertBetweenNativeTypes<
typename primitive_util::PrimitiveTypeToNative<
primitive_src_type>::type,
typename primitive_util::PrimitiveTypeToNative<
primitive_dest_type>::type>(src_literal);
}
}
template <PrimitiveType primitive_src_type>
StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
PrimitiveType primitive_dest_type,
bool bitcast) {
switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type) \
case (type): \
return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal, \
bitcast);
CONVERT_IF_TYPES_MATCH(PRED)
CONVERT_IF_TYPES_MATCH(S8)
CONVERT_IF_TYPES_MATCH(S16)
CONVERT_IF_TYPES_MATCH(S32)
CONVERT_IF_TYPES_MATCH(S64)
CONVERT_IF_TYPES_MATCH(U8)
CONVERT_IF_TYPES_MATCH(U16)
CONVERT_IF_TYPES_MATCH(U32)
CONVERT_IF_TYPES_MATCH(U64)
CONVERT_IF_TYPES_MATCH(F16)
CONVERT_IF_TYPES_MATCH(F32)
CONVERT_IF_TYPES_MATCH(F64)
CONVERT_IF_TYPES_MATCH(BF16)
#undef CONVERT_IF_TYPES_MATCH
case C64:
if (bitcast) {
break;
}
return ConvertIfTypesMatch<primitive_src_type, C64>(src_literal, false);
case C128:
if (bitcast) {
break;
}
return ConvertIfTypesMatch<primitive_src_type, C128>(src_literal, false);
// Other types are not yet supported.
default:
break;
}
return Unimplemented("Converting from type %s to type %s is not implemented.",
PrimitiveType_Name(src_literal.shape().element_type()),
PrimitiveType_Name(primitive_dest_type));
}
StatusOr<Literal> ConvertSwitch(const LiteralBase& literal,
PrimitiveType primitive_dest_type,
bool bitcast) {
TF_RET_CHECK(literal.shape().IsArray());
if (literal.shape().element_type() == primitive_dest_type) {
return literal.Clone();
}
switch (literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
case (type): \
return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \
bitcast);
CONVERT_IF_DEST_TYPE_MATCHES(PRED)
CONVERT_IF_DEST_TYPE_MATCHES(S8)
CONVERT_IF_DEST_TYPE_MATCHES(S16)
CONVERT_IF_DEST_TYPE_MATCHES(S32)
CONVERT_IF_DEST_TYPE_MATCHES(S64)
CONVERT_IF_DEST_TYPE_MATCHES(U8)
CONVERT_IF_DEST_TYPE_MATCHES(U16)
CONVERT_IF_DEST_TYPE_MATCHES(U32)
CONVERT_IF_DEST_TYPE_MATCHES(U64)
CONVERT_IF_DEST_TYPE_MATCHES(F16)
CONVERT_IF_DEST_TYPE_MATCHES(F32)
CONVERT_IF_DEST_TYPE_MATCHES(F64)
CONVERT_IF_DEST_TYPE_MATCHES(BF16)
#undef CONVERT_IF_DEST_TYPE_MATCHES
// Other types are not yet supported.
default:
return Unimplemented("%s from type %s to type %s is not implemented.",
(bitcast ? "Bitcast converting" : "Converting"),
PrimitiveType_Name(literal.shape().element_type()),
PrimitiveType_Name(primitive_dest_type));
}
}
} // namespace
StatusOr<Literal> LiteralBase::Convert(
PrimitiveType primitive_dest_type) const {
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
}
StatusOr<Literal> LiteralBase::BitcastConvert(
PrimitiveType primitive_dest_type) const {
if (primitive_util::BitWidth(shape().element_type()) !=
primitive_util::BitWidth(primitive_dest_type)) {
return InvalidArgument(
"Cannot bitcast convert from %s to %s, bit widths are different: %d != "
"%d",
PrimitiveType_Name(shape().element_type()),
PrimitiveType_Name(primitive_dest_type),
primitive_util::BitWidth(shape().element_type()),
primitive_util::BitWidth(primitive_dest_type));
}
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
}
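// Convert vs. BitcastConvert sketch (illustrative values):
//
//   Literal f(ShapeUtil::MakeShape(F32, {}));
//   f.Set<float>({}, 1.0f);
//   // Value-preserving conversion: yields f64 1.0.
//   StatusOr<Literal> d = f.Convert(F64);
//   // Bit-preserving conversion: yields u32 0x3F800000, the IEEE-754
//   // representation of 1.0f.
//   StatusOr<Literal> u = f.BitcastConvert(U32);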
StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const {
if (!dest_shape.IsTuple()) {
return Convert(dest_shape.element_type());
}
std::vector<Literal> elements;
for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
auto element = LiteralSlice(*this, {i});
TF_ASSIGN_OR_RETURN(
auto new_element,
element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
elements.push_back(std::move(new_element));
}
return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
}
/* static */ Literal MutableLiteralBase::MoveIntoTuple(
absl::Span<Literal> elements) {
std::vector<Shape> element_shapes;
for (const Literal& element : elements) {
element_shapes.push_back(element.shape());
}
Literal literal(ShapeUtil::MakeTupleShape(element_shapes),
/*allocate_arrays=*/false);
for (int i = 0, end = elements.size(); i < end; ++i) {
TF_CHECK_OK(
literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
}
return literal;
}
template <typename NativeT>
void LiteralBase::Piece::CopyElementsWithDynamicBound(
const LiteralBase::Piece& src) {
auto dest_shape = subshape();
auto src_shape = src.subshape();
// At least one of the shapes must be static; it serves as the bound.
CHECK(dest_shape.is_static() || src_shape.is_static());
auto bound_shape = dest_shape.is_static() ? src_shape : dest_shape;
if (ShapeUtil::IsZeroElementArray(dest_shape)) {
return;
}
std::vector<int64> index(dest_shape.rank());
do {
bool out_of_bound = false;
for (int64 i = 0, end = index.size(); i < end; ++i) {
// Do not copy elements beyond the dynamic bound.
if (index[i] >= GetDynamicSize(i) || index[i] >= src.GetDynamicSize(i)) {
out_of_bound = true;
}
}
if (out_of_bound) {
continue;
}
data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape,
index)] =
src.data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
src_shape, index)];
} while (IndexUtil::BumpIndices(bound_shape, absl::MakeSpan(index)));
}
template <typename NativeT>
bool LiteralBase::Piece::EqualElementsInternal(
const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
if (multi_index->size() == subshape().rank()) {
return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
}
for (int64 i = 0; i < GetDynamicSize(multi_index->size()); ++i) {
multi_index->push_back(i);
if (!EqualElementsInternal<NativeT>(other, multi_index)) {
return false;
}
multi_index->pop_back();
}
return true;
}
bool LiteralBase::Piece::EqualDynamicSize(
const LiteralBase::Piece& other) const {
DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
if (subshape().is_static()) {
return true;
}
for (int64 i = 0; i < subshape().rank(); ++i) {
if (GetDynamicSize(i) != other.GetDynamicSize(i)) {
return false;
}
}
return true;
}
bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
if (subshape().is_static() &&
ShapeUtil::Equal(subshape(), other.subshape()) &&
LayoutUtil::IsDenseArray(subshape())) {
CHECK_EQ(size_bytes(), other.size_bytes());
return memcmp(buffer(), other.buffer(), size_bytes()) == 0;
}
std::vector<int64> multi_index;
switch (subshape().element_type()) {
case PRED:
return EqualElementsInternal<bool>(other, &multi_index);
case S8:
return EqualElementsInternal<int8>(other, &multi_index);
case S16:
return EqualElementsInternal<int16>(other, &multi_index);
case S32:
return EqualElementsInternal<int32>(other, &multi_index);
case S64:
return EqualElementsInternal<int64>(other, &multi_index);
case U8:
return EqualElementsInternal<uint8>(other, &multi_index);
case U16:
return EqualElementsInternal<uint16>(other, &multi_index);
case U32:
return EqualElementsInternal<uint32>(other, &multi_index);
case U64:
return EqualElementsInternal<uint64>(other, &multi_index);
case F32:
return EqualElementsInternal<float>(other, &multi_index);
case F64:
return EqualElementsInternal<double>(other, &multi_index);
case F16:
return EqualElementsInternal<half>(other, &multi_index);
case BF16:
return EqualElementsInternal<bfloat16>(other, &multi_index);
case C64:
return EqualElementsInternal<complex64>(other, &multi_index);
case C128:
return EqualElementsInternal<complex128>(other, &multi_index);
default:
LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type "
<< PrimitiveType_Name(subshape().element_type());
}
}
bool LiteralBase::operator==(const LiteralBase& other) const {
// Check the structure of tuple literals here; checks for dense arrays are
// performed below.
if (!ShapeUtil::EqualStructure(shape(), other.shape())) {
return false;
}
return root_piece().ForEachSubpieceWithBool(
[&](const ShapeIndex& index, const Piece& piece) {
const Piece& other_piece = other.piece(index);
const Shape& subshape = piece.subshape();
const Shape& other_subshape = other_piece.subshape();
if (subshape.element_type() != other_subshape.element_type()) {
return false;
}
if (!piece.subshape().IsArray()) {
return true;
}
if (subshape.rank() != other_subshape.rank()) {
return false;
}
for (int64 i = 0; i < subshape.rank(); ++i) {
if (piece.GetDynamicSize(i) != other_piece.GetDynamicSize(i)) {
return false;
}
}
if (!piece.EqualElements(other_piece)) {
return false;
}
return true;
});
}
namespace {
template <typename NativeT>
static bool AllElementsEqualValue(absl::Span<const NativeT> data,
NativeT value) {
for (int64 i = 0; i < data.size(); ++i) {
if (data[i] != value) {
return false;
}
}
return true;
}
} // namespace
bool LiteralBase::IsAll(int8 value) const {
return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index,
const Piece& piece) {
if (!piece.subshape().IsArray()) {
return true;
}
auto piece_is_all = [&]() {
switch (shape().element_type()) {
case U8:
if (value >= 0) {
return AllElementsEqualValue<uint8>(piece.data<uint8>(), value);
}
return false;
case U16:
if (value >= 0) {
return AllElementsEqualValue<uint16>(piece.data<uint16>(), value);
}
return false;
case U32:
if (value >= 0) {
return AllElementsEqualValue<uint32>(piece.data<uint32>(), value);
}
return false;
case U64:
if (value >= 0) {
return AllElementsEqualValue<uint64>(piece.data<uint64>(), value);
}
return false;
case S8:
return AllElementsEqualValue<int8>(piece.data<int8>(), value);
case S16:
return AllElementsEqualValue<int16>(piece.data<int16>(), value);
case S32:
return AllElementsEqualValue<int32>(piece.data<int32>(), value);
case S64:
return AllElementsEqualValue<int64>(piece.data<int64>(), value);
case F32:
return AllElementsEqualValue<float>(piece.data<float>(), value);
case F64:
return AllElementsEqualValue<double>(piece.data<double>(), value);
case F16:
return AllElementsEqualValue<half>(piece.data<half>(),
static_cast<half>(value));
case BF16:
return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
static_cast<bfloat16>(value));
case PRED:
if (value == 0) {
return AllElementsEqualValue<bool>(piece.data<bool>(), false);
}
if (value == 1) {
return AllElementsEqualValue<bool>(piece.data<bool>(), true);
}
return false;
default:
return false;
}
return false;
};
if (!piece_is_all()) {
return false;
}
return true;
});
}
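// IsAll sketch (illustrative): the int8 query value is compared after
// conversion to the element type, so negative values can never match an
// unsigned array.
//
//   Literal u(ShapeUtil::MakeShape(U8, {2}));
//   u.Set<uint8>({0}, 7);
//   u.Set<uint8>({1}, 7);
//   CHECK(u.IsAll(7));
//   CHECK(!u.IsAll(-1));  // negative never matches unsigned elements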
bool LiteralBase::IsAllFloat(float value) const {
return root_piece().ForEachSubpieceWithBool(
[&](const ShapeIndex& index, const Piece& piece) {
if (!piece.subshape().IsArray()) {
return true;
}
switch (shape().element_type()) {
case F32:
return AllElementsEqualValue<float>(piece.data<float>(), value);
case F64:
return AllElementsEqualValue<double>(piece.data<double>(), value);
case F16:
return AllElementsEqualValue<half>(piece.data<half>(),
static_cast<half>(value));
case BF16:
return AllElementsEqualValue<bfloat16>(
piece.data<bfloat16>(), static_cast<bfloat16>(value));
default:
return false;
}
});
}
bool LiteralBase::IsAllComplex(complex64 value) const {
switch (shape().element_type()) {
case C64:
return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
value);
case C128:
return AllElementsEqualValue<complex128>(root_piece().data<complex128>(),
value);
default:
return false;
}
}
bool LiteralBase::IsAllFirst() const {
return root_piece().ForEachSubpieceWithBool(
[&](const ShapeIndex& index, const Piece& piece) {
if (!piece.subshape().IsArray()) {
return true;
}
        // A zero-element array is never "all first" because it has no
        // first element.
if (ShapeUtil::IsZeroElementArray(piece.subshape())) {
return false;
}
auto piece_is_all = [&]() {
switch (piece.subshape().element_type()) {
case PRED: {
auto data = piece.data<bool>();
return AllElementsEqualValue<bool>(data, data[0]);
}
// 8 bit types
case S8: {
auto data = piece.data<int8>();
return AllElementsEqualValue<int8>(data, data[0]);
}
case U8: {
auto data = piece.data<uint8>();
return AllElementsEqualValue<uint8>(data, data[0]);
}
// 16 bit types
case BF16: {
auto data = piece.data<bfloat16>();
return AllElementsEqualValue<bfloat16>(data, data[0]);
}
case F16: {
auto data = piece.data<half>();
return AllElementsEqualValue<half>(data, data[0]);
}
case S16: {
auto data = piece.data<int16>();
return AllElementsEqualValue<int16>(data, data[0]);
}
case U16: {
auto data = piece.data<uint16>();
return AllElementsEqualValue<uint16>(data, data[0]);
}
// 32 bit types
case F32: {
auto data = piece.data<float>();
return AllElementsEqualValue<float>(data, data[0]);
}
case U32: {
auto data = piece.data<uint32>();
return AllElementsEqualValue<uint32>(data, data[0]);
}
case S32: {
auto data = piece.data<int32>();
return AllElementsEqualValue<int32>(data, data[0]);
}
// 64 bit types
case C64: {
auto data = piece.data<complex64>();
return AllElementsEqualValue<complex64>(data, data[0]);
}
case F64: {
auto data = piece.data<double>();
return AllElementsEqualValue<double>(data, data[0]);
}
case S64: {
auto data = piece.data<int64>();
return AllElementsEqualValue<int64>(data, data[0]);
}
case U64: {
auto data = piece.data<uint64>();
return AllElementsEqualValue<uint64>(data, data[0]);
}
case C128: {
auto data = piece.data<complex128>();
return AllElementsEqualValue<complex128>(data, data[0]);
}
default:
return false;
}
};
        return piece_is_all();
});
}
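// Illustrative sketch (comment only): unlike IsAll, IsAllFirst needs no
// candidate value up front; it compares everything against element 0.
//
//   Literal l = LiteralUtil::CreateR1<int32>({7, 7, 7});
//   CHECK(l.IsAllFirst());
//   Literal m = LiteralUtil::CreateR1<int32>({7, 8, 7});
//   CHECK(!m.IsAllFirst());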
bool LiteralBase::IsR1Iota() const {
if (!shape().IsArray()) {
return false;
}
if (shape().rank() != 1) {
return false;
}
auto is_iota_at_idx = [&](const int64 idx) {
switch (shape().element_type()) {
case U8:
return static_cast<int64>(Get<uint8>({idx})) == idx;
case U16:
return static_cast<int64>(Get<uint16>({idx})) == idx;
case U32:
return static_cast<int64>(Get<uint32>({idx})) == idx;
case U64:
return static_cast<int64>(Get<uint64>({idx})) == idx;
case S8:
return Get<int8>({idx}) == idx;
case S16:
return Get<int16>({idx}) == idx;
case S32:
return Get<int32>({idx}) == idx;
case S64:
return Get<int64>({idx}) == idx;
case F32:
return Get<float>({idx}) == idx;
case F64:
return Get<double>({idx}) == idx;
case F16:
return Get<half>({idx}) == static_cast<half>(idx);
case BF16:
return Get<bfloat16>({idx}) == static_cast<bfloat16>(idx);
case C64:
return Get<complex64>({idx}) == complex64(idx, 0.0f);
case C128:
return Get<complex128>({idx}) == complex128(idx, 0.0f);
      // PRED, token, opaque, tuple, etc. are never iota.
default:
return false;
}
};
const int64 elements = ShapeUtil::ElementsIn(shape());
for (int64 idx = 0; idx < elements; ++idx) {
if (!is_iota_at_idx(idx)) {
return false;
}
}
return true;
}
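// Illustrative sketch (comment only): only rank-1 arrays whose elements
// count up from zero qualify.
//
//   Literal iota = LiteralUtil::CreateR1<int32>({0, 1, 2, 3});
//   CHECK(iota.IsR1Iota());
//   Literal shifted = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
//   CHECK(!shifted.IsR1Iota());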
bool LiteralBase::IsZero(absl::Span<const int64> indices) const {
CHECK(shape().IsArray());
switch (shape().element_type()) {
case U8:
return Get<uint8>(indices) == 0;
case U16:
return Get<uint16>(indices) == 0;
case U32:
return Get<uint32>(indices) == 0;
case U64:
return Get<uint64>(indices) == 0;
case S8:
return Get<int8>(indices) == 0;
case S16:
return Get<int16>(indices) == 0;
case S32:
return Get<int32>(indices) == 0;
case S64:
return Get<int64>(indices) == 0;
case F32:
return Get<float>(indices) == 0.0f;
case F64:
return Get<double>(indices) == 0.0;
case C64:
return Get<complex64>(indices) == complex64(0.0f, 0.0f);
case C128:
return Get<complex128>(indices) == complex128(0.0f, 0.0f);
case F16:
return Get<half>(indices) == static_cast<half>(0.0f);
case BF16:
return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
case PRED:
return Get<bool>(indices) == false;
    default:
      LOG(FATAL) << "Unsupported element type for IsZero: "
                 << PrimitiveType_Name(shape().element_type());
}
}
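// Illustrative sketch (comment only; assumes LiteralUtil::CreateR2):
//
//   Literal l = LiteralUtil::CreateR2<int32>({{0, 1}, {2, 0}});
//   CHECK(l.IsZero({0, 0}));   // Element (0, 0) is zero.
//   CHECK(!l.IsZero({0, 1}));  // Element (0, 1) is 1.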
namespace {
template <typename RepeatedFieldT, typename NativeT>
void CopyToRepeatedField(RepeatedFieldT* dest,
const absl::Span<const NativeT> src) {
*dest = RepeatedFieldT(src.begin(), src.end());
}
} // namespace
void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
*proto->mutable_shape() = subshape().ToProto();
switch (subshape().element_type()) {
case PRED:
CopyToRepeatedField(proto->mutable_preds(), data<bool>());
break;
case S8:
proto->set_s8s(static_cast<const signed char*>(data<int8>().data()),
element_count());
break;
case U8:
proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
element_count());
break;
case U32:
CopyToRepeatedField(proto->mutable_u32s(), data<uint32>());
break;
case U64:
CopyToRepeatedField(proto->mutable_u64s(), data<uint64>());
break;
case S32:
CopyToRepeatedField(proto->mutable_s32s(), data<int32>());
break;
case S64:
CopyToRepeatedField(proto->mutable_s64s(), data<int64>());
break;
case U16:
*proto->mutable_u16s() = string(
reinterpret_cast<const char*>(data<uint16_t>().data()), size_bytes());
if (!kLittleEndian) {
ConvertEndianShort(proto->mutable_u16s());
}
break;
case S16:
*proto->mutable_s16s() = string(
reinterpret_cast<const char*>(data<int16_t>().data()), size_bytes());
if (!kLittleEndian) {
ConvertEndianShort(proto->mutable_s16s());
}
break;
case F16:
*proto->mutable_f16s() = string(
reinterpret_cast<const char*>(data<half>().data()), size_bytes());
if (!kLittleEndian) {
ConvertEndianShort(proto->mutable_f16s());
}
break;
case BF16:
*proto->mutable_bf16s() = string(
reinterpret_cast<const char*>(data<bfloat16>().data()), size_bytes());
if (!kLittleEndian) {
ConvertEndianShort(proto->mutable_bf16s());
}
break;
case F32:
CopyToRepeatedField(proto->mutable_f32s(), data<float>());
break;
case F64:
CopyToRepeatedField(proto->mutable_f64s(), data<double>());
break;
case C64:
for (complex64 value : data<complex64>()) {
proto->add_c64s(value.real());
proto->add_c64s(value.imag());
}
break;
case C128:
for (complex128 value : data<complex128>()) {
proto->add_c128s(value.real());
proto->add_c128s(value.imag());
}
break;
case TUPLE:
case TOKEN:
      // Nothing to do but assign the shape, which is done above.
return;
default:
// TODO(b/111551621): Support serializing more PrimitiveTypes.
LOG(FATAL) << "Unhandled primitive type "
<< PrimitiveType_Name(subshape().element_type());
}
}
const void* LiteralBase::Piece::untyped_data() const {
CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
return buffer();
}
void* LiteralBase::Piece::untyped_data() {
CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape());
return buffer();
}
namespace {
template <typename RepeatedFieldT, typename NativeT>
Status CopyFromRepeatedField(absl::Span<NativeT> dest,
const RepeatedFieldT& src) {
if (dest.size() != src.size()) {
return InvalidArgument(
"Expected %lu elements in LiteralProto repeated field, has %d",
dest.size(), src.size());
}
std::copy(src.begin(), src.end(), dest.begin());
return Status::OK();
}
} // namespace
Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
// These conditions should have been checked in
// MutableLiteralBase::CreateFromProto.
TF_RET_CHECK(proto.has_shape());
Shape shape(proto.shape());
TF_RET_CHECK(LayoutUtil::HasLayout(shape));
TF_RET_CHECK(ShapeUtil::Equal(shape, subshape()));
switch (subshape().element_type()) {
case PRED:
TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
break;
case S8: {
auto s8_data = data<int8>();
TF_RET_CHECK(proto.s8s().size() == s8_data.size());
std::copy(proto.s8s().begin(), proto.s8s().end(), s8_data.begin());
} break;
case U8: {
auto u8_data = data<uint8>();
TF_RET_CHECK(proto.u8s().size() == u8_data.size());
std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin());
} break;
case S32:
TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int32>(), proto.s32s()));
break;
case S64:
TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int64>(), proto.s64s()));
break;
case U32:
TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint32>(), proto.u32s()));
break;
case U64:
TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64>(), proto.u64s()));
break;
case S16: {
const string& s(proto.s16s());
TF_RET_CHECK(data<int16_t>().size() * sizeof(int16_t) == s.size());
memcpy(untyped_data(), s.data(), s.size());
if (!kLittleEndian) {
ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
}
} break;
case U16: {
const string& s(proto.u16s());
TF_RET_CHECK(data<uint16_t>().size() * sizeof(uint16_t) == s.size());
memcpy(untyped_data(), s.data(), s.size());
if (!kLittleEndian) {
ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
}
} break;
case F16: {
const string& s(proto.f16s());
TF_RET_CHECK(data<half>().size() * sizeof(half) == s.size());
memcpy(untyped_data(), s.data(), s.size());
if (!kLittleEndian) {
ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
}
} break;
case BF16: {
const string& s(proto.bf16s());
TF_RET_CHECK(data<bfloat16>().size() * sizeof(bfloat16) == s.size());
memcpy(untyped_data(), s.data(), s.size());
if (!kLittleEndian) {
ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
}
} break;
case F32:
TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<float>(), proto.f32s()));
break;
case F64:
TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<double>(), proto.f64s()));
break;
case C64: {
auto complex_data = data<complex64>();
      const int64 complex_data_size_doubled = complex_data.size() * 2;
      TF_RET_CHECK(proto.c64s_size() == complex_data_size_doubled);
      for (int64 i = 0, end = complex_data.size(); i < end; ++i) {
complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
}
break;
}
case C128: {
auto complex_data = data<complex128>();
const int64 complex_data_size_doubled = complex_data.size() * 2;
TF_RET_CHECK(proto.c128s_size() == complex_data_size_doubled);
for (int64 i = 0, end = complex_data.size(); i < end; ++i) {
complex_data[i] =
complex128{proto.c128s(i * 2), proto.c128s(i * 2 + 1)};
}
break;
}
case TUPLE:
return InvalidArgument("Should not be called on tuple shapes: %s",
ShapeUtil::HumanString(subshape()));
default:
return InvalidArgument("Is called on unsupported shape: %s",
ShapeUtil::HumanString(subshape()));
}
return Status::OK();
}
LiteralProto LiteralBase::ToProto() const {
LiteralProto proto;
root_piece().ForEachSubpiece(
[&](const ShapeIndex& index, const Piece& piece) {
LiteralProto* proto_piece = &proto;
for (int64 i : index) {
while (proto_piece->tuple_literals_size() <= i) {
proto_piece->add_tuple_literals();
}
proto_piece = proto_piece->mutable_tuple_literals(i);
}
piece.WriteToProto(proto_piece);
});
return proto;
}
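// Illustrative round-trip sketch (comment only; assumes `literal` is an
// xla::Literal in scope). ToProto pairs with the static
// MutableLiteralBase::CreateFromProto declared in literal.h:
//
//   LiteralProto proto = literal.ToProto();
//   StatusOr<Literal> restored = Literal::CreateFromProto(proto);
//   CHECK(restored.ValueOrDie() == literal);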
const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
return piece(shape_index).untyped_data();
}
void* MutableLiteralBase::untyped_data(const ShapeIndex& shape_index) {
return piece(shape_index).untyped_data();
}
int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
return piece(shape_index).size_bytes();
}
string LiteralBase::GetR1U8AsString() const {
CHECK(shape().IsArray());
CHECK_EQ(shape().rank(), 1);
CHECK_EQ(shape().element_type(), U8);
return string(absl::bit_cast<const char*>(data<uint8>().data()),
ShapeUtil::ElementsIn(shape()));
}
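// Illustrative sketch (comment only; assumes LiteralUtil::CreateR1):
//
//   Literal s = LiteralUtil::CreateR1<uint8>({'h', 'i'});
//   CHECK_EQ(s.GetR1U8AsString(), "hi");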
void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape,
Piece* src_piece,
Piece* dest_piece) {
DCHECK(ShapeUtil::Equal(src_piece->subshape(), dest_piece->subshape()))
<< "src_piece has shape: "
<< ShapeUtil::HumanString(src_piece->subshape())
<< "dest_piece has shape: "
<< ShapeUtil::HumanString(dest_piece->subshape());
if (shape.IsTuple()) {
for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
const Shape& subshape = shape.tuple_shapes(i);
auto child_piece = Piece();
child_piece.set_subshape(&subshape);
CopyPieceSubtree(subshape, &src_piece->child(i), &child_piece);
dest_piece->emplace_back(std::move(child_piece));
}
} else if (shape.IsArray()) {
dest_piece->set_buffer(src_piece->buffer());
dest_piece->set_dynamic_size_buffer(src_piece->dynamic_size_buffer());
} else {
    // If the shape is neither an array nor a tuple, it must be zero-sized;
    // otherwise memory would need to be allocated for it here.
CHECK_EQ(dest_piece->size_bytes(), 0);
}
}
MutableLiteralBase::~MutableLiteralBase() {}
MutableBorrowingLiteral::MutableBorrowingLiteral(
const MutableBorrowingLiteral& literal)
: MutableLiteralBase() {
shape_ = absl::make_unique<Shape>(literal.shape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get());
CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
}
MutableBorrowingLiteral& MutableBorrowingLiteral::operator=(
const MutableBorrowingLiteral& literal) {
shape_ = absl::make_unique<Shape>(literal.shape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get());
CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
return *this;
}
MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal)
: MutableLiteralBase() {
shape_ = absl::make_unique<Shape>(literal->shape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get());
CopyPieceSubtree(*shape_, &literal->root_piece(), root_piece_);
}
MutableBorrowingLiteral::MutableBorrowingLiteral(
MutableBorrowingLiteral literal, const ShapeIndex& view_root)
: MutableLiteralBase() {
shape_ = absl::make_unique<Shape>(literal.piece(view_root).subshape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get());
CopyPieceSubtree(*shape_, &literal.piece(view_root), root_piece_);
}
MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr,
const Shape& shape)
: MutableLiteralBase() {
shape_ = absl::make_unique<Shape>(shape);
CHECK(LayoutUtil::HasLayout(*shape_));
CHECK(!shape_->IsTuple());
root_piece_ = new Piece();
root_piece_->set_buffer(const_cast<char*>(src_buf_ptr));
root_piece_->set_subshape(shape_.get());
}
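// Illustrative sketch (comment only): MutableBorrowingLiteral wraps
// caller-owned storage without copying, so writes go straight through to
// the underlying buffer, which must outlive the view:
//
//   float buf[4] = {0.0f, 0.0f, 0.0f, 0.0f};
//   MutableBorrowingLiteral view(
//       reinterpret_cast<char*>(buf),
//       ShapeUtil::MakeShapeWithDescendingLayout(F32, {4}));
//   view.Set<float>({0}, 1.0f);  // buf[0] is now 1.0f.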
MutableBorrowingLiteral::MutableBorrowingLiteral(absl::Span<char*> src_buf_ptrs,
const Shape& shape)
: MutableLiteralBase() {
shape_ = absl::make_unique<Shape>(shape);
if (!shape_->IsTuple()) {
CHECK_EQ(src_buf_ptrs.size(), 1);
root_piece_ = new Piece();
root_piece_->set_buffer(const_cast<char*>(src_buf_ptrs[0]));
root_piece_->set_subshape(shape_.get());
} else {
CHECK(!ShapeUtil::IsNestedTuple(*shape_));
CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get());
    for (int i = 0, end = src_buf_ptrs.size(); i < end; ++i) {
Piece child_piece;
const auto& src_shape = shape_->tuple_shapes(i);
CHECK(src_shape.IsArray());
child_piece.set_subshape(&src_shape);
child_piece.set_buffer(src_buf_ptrs[i]);
root_piece_->emplace_back(std::move(child_piece));
}
}
}
MutableBorrowingLiteral::~MutableBorrowingLiteral() { delete root_piece_; }
LiteralSlice::LiteralSlice(const LiteralBase& literal)
: LiteralBase(), root_piece_(&literal.root_piece()) {}
LiteralSlice::LiteralSlice(const LiteralBase& literal,
const ShapeIndex& view_root)
: LiteralBase(), root_piece_(&literal.piece(view_root)) {}
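// Illustrative sketch (comment only; assumes `tuple_literal` is a Literal
// in scope): LiteralSlice is a non-owning, read-only view, so no buffers
// are copied and the viewed literal must outlive the slice.
//
//   LiteralSlice whole(tuple_literal);       // Views the entire literal.
//   LiteralSlice elem0(tuple_literal, {0});  // Views tuple element 0.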
void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
CHECK(shape.IsTuple());
for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
const Shape& subshape = shape.tuple_shapes(i);
auto child_piece = Piece();
child_piece.set_subshape(&subshape);
if (subshape.IsTuple()) {
BuildPieceSubtree(subshape, &child_piece);
}
piece->emplace_back(std::move(child_piece));
}
}
BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
: LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
CHECK(shape_->IsArray());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = Piece();
root_piece_.set_buffer(const_cast<char*>(src_buf_ptr));
root_piece_.set_subshape(shape_.get());
}
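// Illustrative sketch (comment only): the read-only counterpart of
// MutableBorrowingLiteral. `data` must stay alive while `view` is used:
//
//   const float data[2] = {1.0f, 2.0f};
//   BorrowingLiteral view(
//       reinterpret_cast<const char*>(data),
//       ShapeUtil::MakeShapeWithDescendingLayout(F32, {2}));
//   CHECK_EQ(view.Get<float>({1}), 2.0f);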
BorrowingLiteral::BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
const Shape& shape)
: LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
CHECK(shape_->IsTuple());
CHECK(!ShapeUtil::IsNestedTuple(*shape_));
CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
root_piece_ = Piece();
root_piece_.set_subshape(shape_.get());
BuildPieceSubtree(*shape_, &root_piece_);
for (int i = 0, end = src_buf_ptrs.size(); i < end; ++i) {
const auto& src_shape = shape_->tuple_shapes(i);
CHECK(src_shape.IsArray());
root_piece_.child(i).set_buffer(const_cast<char*>(src_buf_ptrs[i]));
}
}
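// Illustrative sketch (comment only) for the tuple form: one buffer per
// (non-nested) tuple element, in tuple order. `a` and `b` are hypothetical
// const char* pointers to F32[2] and S32[3] data respectively:
//
//   Shape tuple_shape = ShapeUtil::MakeTupleShape(
//       {ShapeUtil::MakeShapeWithDescendingLayout(F32, {2}),
//        ShapeUtil::MakeShapeWithDescendingLayout(S32, {3})});
//   BorrowingLiteral view({a, b}, tuple_shape);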
} // namespace xla