Dynamic literal support
PiperOrigin-RevId: 321836977 Change-Id: Ib5524846de424da20643c6982f6454614d9ffa07
This commit is contained in:
parent
c49b3f570d
commit
ce190ec244
@ -48,10 +48,6 @@ namespace {
|
||||
using absl::StrCat;
|
||||
|
||||
constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
|
||||
// Literals can be used as DMA targets, which can require alignment. We
|
||||
// force a tensorflow::Allocator::kAllocatorAlignment-byte minimum
|
||||
// alignment.
|
||||
constexpr int kMinimumAlignment = 64;
|
||||
|
||||
// Converts between little and big endian.
|
||||
//
|
||||
@ -137,14 +133,12 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
|
||||
}
|
||||
} else if (shape.IsArray()) {
|
||||
if (allocate_arrays) {
|
||||
// Literals can be used as DMA targets, which can require alignment. We
|
||||
// force a tensorflow::Allocator::kAllocatorAlignment-byte minimum
|
||||
// alignment.
|
||||
constexpr int kMinimumAlignment = 64;
|
||||
piece->set_buffer(static_cast<char*>(tensorflow::port::AlignedMalloc(
|
||||
piece->size_bytes(), kMinimumAlignment)));
|
||||
if (shape.is_dynamic()) {
|
||||
CHECK_EQ(piece->dynamic_size_buffer(), nullptr);
|
||||
piece->set_dynamic_size_buffer(
|
||||
static_cast<int32*>(tensorflow::port::AlignedMalloc(
|
||||
piece->dynamic_size_buffer_bytes(), kMinimumAlignment)));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If the shape is neither an array nor tuple, then it must be
|
||||
@ -177,9 +171,6 @@ void Literal::DeallocateBuffers() {
|
||||
if (piece->buffer() != nullptr) {
|
||||
tensorflow::port::AlignedFree(piece->buffer());
|
||||
}
|
||||
if (piece->dynamic_size_buffer() != nullptr) {
|
||||
tensorflow::port::AlignedFree(piece->dynamic_size_buffer());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@ -208,15 +199,6 @@ Literal LiteralBase::CreateFromShape(const Shape& shape) {
|
||||
return literal;
|
||||
}
|
||||
|
||||
int32 LiteralBase::GetDynamicSize(int64 dim_index) const {
|
||||
return GetDynamicSize(dim_index, {});
|
||||
}
|
||||
|
||||
int32 LiteralBase::GetDynamicSize(int64 dim_index,
|
||||
const ShapeIndex& shape_index) const {
|
||||
return piece(shape_index).GetDynamicSize(dim_index);
|
||||
}
|
||||
|
||||
absl::optional<int64> LiteralBase::GetFirstInteger() const {
|
||||
switch (shape().element_type()) {
|
||||
case U8:
|
||||
@ -399,9 +381,7 @@ std::vector<Literal> Literal::DecomposeTuple() {
|
||||
|
||||
// Move the respective buffer over to the element Literal.
|
||||
dest_piece->set_buffer(src_piece.buffer());
|
||||
dest_piece->set_dynamic_size_buffer(src_piece.dynamic_size_buffer());
|
||||
src_piece.set_buffer(nullptr);
|
||||
src_piece.set_dynamic_size_buffer(nullptr);
|
||||
});
|
||||
}
|
||||
// Set this literal to be nil-shaped.
|
||||
@ -427,51 +407,23 @@ void CopyElementsBetween(absl::Span<NativeT> dest,
|
||||
src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
|
||||
} while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index)));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int32 LiteralBase::Piece::GetDynamicSize(int64 dim_index) const {
|
||||
CHECK(LayoutUtil::IsDenseArray(subshape()));
|
||||
if (!subshape_->is_dynamic_dimension(dim_index)) {
|
||||
// This is a static dimension, return size.
|
||||
return subshape_->dimensions(dim_index);
|
||||
}
|
||||
CHECK_NE(dynamic_size_buffer(), nullptr);
|
||||
return dynamic_size_buffer_[dim_index];
|
||||
}
|
||||
|
||||
void LiteralBase::Piece::SetDynamicSize(int64 dim_index, int32 size) {
|
||||
CHECK(LayoutUtil::IsDenseArray(subshape()));
|
||||
CHECK(subshape_->is_dynamic_dimension(dim_index));
|
||||
if (dynamic_size_buffer() == nullptr) {
|
||||
// Lazily initialize the dynamic size buffer.
|
||||
set_dynamic_size_buffer(static_cast<int32*>(tensorflow::port::AlignedMalloc(
|
||||
dynamic_size_buffer_bytes(), kMinimumAlignment)));
|
||||
/*for (int64 i = 0; i < subshape().rank(); ++i) {
|
||||
// Initialized to -1 to help debug.
|
||||
dynamic_size_buffer_[i] = -1;
|
||||
}*/
|
||||
}
|
||||
dynamic_size_buffer_[dim_index] = size;
|
||||
}
|
||||
|
||||
Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src,
|
||||
bool only_dynamic_bound) {
|
||||
Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
|
||||
CHECK(subshape_ != nullptr);
|
||||
CHECK(src.subshape_ != nullptr);
|
||||
if (ShapeUtil::Equal(subshape(), src.subshape())) {
|
||||
// If the layouts are equal it's faster just to memcpy.
|
||||
memcpy(buffer(), src.buffer(), src.size_bytes());
|
||||
} else {
|
||||
TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape()));
|
||||
std::vector<int64> origin(subshape().rank(), 0);
|
||||
switch (subshape().element_type()) {
|
||||
#define COPY_ELEMENTS(XLA_T, NATIVE_T) \
|
||||
case (XLA_T): \
|
||||
if (only_dynamic_bound) { \
|
||||
CopyElementsWithDynamicBound<NATIVE_T>(src); \
|
||||
} else { \
|
||||
CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
|
||||
subshape(), src.subshape()); \
|
||||
} \
|
||||
break;
|
||||
COPY_ELEMENTS(U8, uint8);
|
||||
COPY_ELEMENTS(U16, uint16);
|
||||
@ -495,55 +447,22 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src,
|
||||
PrimitiveType_Name(subshape().element_type()));
|
||||
}
|
||||
}
|
||||
DCHECK_EQ(dynamic_size_buffer_bytes(), src.dynamic_size_buffer_bytes());
|
||||
if (subshape().is_dynamic() && src.subshape().is_dynamic()) {
|
||||
CHECK_NE(dynamic_size_buffer_, nullptr);
|
||||
CHECK_NE(src.dynamic_size_buffer_, nullptr);
|
||||
memcpy(dynamic_size_buffer(), src.dynamic_size_buffer(),
|
||||
src.dynamic_size_buffer_bytes());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void MutableLiteralBase::SetDynamicSize(int64 dim_index, int32 size) {
|
||||
return SetDynamicSize(dim_index, {}, size);
|
||||
}
|
||||
|
||||
void MutableLiteralBase::SetDynamicSize(int64 dim_index,
|
||||
const ShapeIndex& shape_index,
|
||||
int32 size) {
|
||||
Shape* subshape_ = ShapeUtil::GetMutableSubshape(shape_.get(), shape_index);
|
||||
CHECK_GE(subshape_->dimensions(dim_index), size);
|
||||
if (subshape_->dimensions(dim_index) == size) {
|
||||
subshape_->set_dynamic_dimension(dim_index, false);
|
||||
return;
|
||||
}
|
||||
subshape_->set_dynamic_dimension(dim_index, true);
|
||||
piece(shape_index).SetDynamicSize(dim_index, size);
|
||||
}
|
||||
|
||||
Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
|
||||
const ShapeIndex& dest_shape_index,
|
||||
const ShapeIndex& src_shape_index,
|
||||
bool only_dynamic_bound) {
|
||||
const ShapeIndex& src_shape_index) {
|
||||
const Shape& dest_subshape =
|
||||
ShapeUtil::GetSubshape(shape(), dest_shape_index);
|
||||
const Shape& src_subshape =
|
||||
ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
|
||||
if (only_dynamic_bound) {
|
||||
auto bound_shape = dest_subshape.is_static() ? src_subshape : dest_subshape;
|
||||
auto compact_shape =
|
||||
dest_subshape.is_static() ? dest_subshape : src_subshape;
|
||||
CHECK(ShapeUtil::DynamicShapeIsCompatible(compact_shape, bound_shape))
|
||||
<< compact_shape.ToString() << " vs " << bound_shape.ToString();
|
||||
} else {
|
||||
if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
|
||||
return InvalidArgument(
|
||||
"Destination subshape incompatible with source subshape: %s vs %s",
|
||||
ShapeUtil::HumanString(dest_subshape),
|
||||
ShapeUtil::HumanString(src_subshape));
|
||||
}
|
||||
}
|
||||
return root_piece_->ForEachMutableSubpieceWithStatus(
|
||||
[&](const ShapeIndex& index, Piece* piece) {
|
||||
if (!piece->subshape().IsArray()) {
|
||||
@ -567,9 +486,7 @@ Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
|
||||
for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
|
||||
src_piece_index.push_back(index[i]);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
piece->CopyFrom(src_literal.piece(src_piece_index),
|
||||
/*only_dynamic_bound=*/only_dynamic_bound));
|
||||
TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index)));
|
||||
return Status::OK();
|
||||
});
|
||||
}
|
||||
@ -597,9 +514,7 @@ Status Literal::MoveFrom(Literal&& src_literal,
|
||||
}
|
||||
Piece& dest_piece = piece(dest_index);
|
||||
tensorflow::port::AlignedFree(dest_piece.buffer());
|
||||
tensorflow::port::AlignedFree(dest_piece.dynamic_size_buffer());
|
||||
dest_piece.set_buffer(src_piece.buffer());
|
||||
dest_piece.set_dynamic_size_buffer(src_piece.dynamic_size_buffer());
|
||||
});
|
||||
|
||||
src_literal.shape_ = absl::make_unique<Shape>(ShapeUtil::MakeNil());
|
||||
@ -714,41 +629,6 @@ Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
|
||||
return result;
|
||||
}
|
||||
|
||||
Literal LiteralBase::ToBoundedDynamic(const Shape& bounded_shape) const {
|
||||
CHECK(bounded_shape.is_dynamic());
|
||||
Literal result(bounded_shape);
|
||||
ShapeUtil::ForEachSubshape(
|
||||
shape(), [&](const Shape& subshape, const ShapeIndex& index) {
|
||||
if (!subshape.IsArray()) {
|
||||
return;
|
||||
}
|
||||
for (int64 i = 0; i < subshape.rank(); ++i) {
|
||||
result.SetDynamicSize(i, subshape.dimensions(i));
|
||||
}
|
||||
});
|
||||
TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
Literal LiteralBase::ToStatic() const {
|
||||
// Create new shape with 'new_layout' set at the given shape index.
|
||||
Shape new_shape = shape();
|
||||
ShapeUtil::ForEachMutableSubshape(
|
||||
&new_shape, [this](Shape* subshape, const ShapeIndex& index) {
|
||||
if (!subshape->IsArray()) {
|
||||
return;
|
||||
}
|
||||
for (int64 i = 0; i < subshape->rank(); ++i) {
|
||||
subshape->set_dynamic_dimension(i, false);
|
||||
subshape->set_dimensions(i, GetDynamicSize(i, index));
|
||||
}
|
||||
});
|
||||
Literal result(new_shape);
|
||||
TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true));
|
||||
return result;
|
||||
}
|
||||
|
||||
StatusOr<Literal> LiteralBase::Broadcast(
|
||||
const Shape& result_shape, absl::Span<const int64> dimensions) const {
|
||||
if (!shape().IsArray()) {
|
||||
@ -772,11 +652,6 @@ StatusOr<Literal> LiteralBase::Broadcast(
|
||||
const int64 primitive_size =
|
||||
ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
|
||||
|
||||
for (int64 i = 0; i < dimensions.size(); ++i) {
|
||||
int64 dynamic_size = GetDynamicSize(i);
|
||||
result.SetDynamicSize(dimensions[i], dynamic_size);
|
||||
}
|
||||
|
||||
ShapeUtil::ForEachIndex(
|
||||
result_shape, [&](absl::Span<const int64> output_index) {
|
||||
for (int64 i = 0; i < dimensions.size(); ++i) {
|
||||
@ -799,9 +674,6 @@ StatusOr<Literal> LiteralBase::Reshape(
|
||||
if (!shape().IsArray()) {
|
||||
return InvalidArgument("Reshape does not support tuples.");
|
||||
}
|
||||
if (shape().is_dynamic()) {
|
||||
return Unimplemented("Dynamic reshape is not implemented.");
|
||||
}
|
||||
Literal output;
|
||||
if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
|
||||
output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank()));
|
||||
@ -856,9 +728,6 @@ Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
|
||||
layout->add_minor_to_major(inverse_permutation[index]);
|
||||
}
|
||||
Literal new_literal(permuted_shape);
|
||||
for (int64 i = 0; i < shape().rank(); i++) {
|
||||
new_literal.SetDynamicSize(inverse_permutation[i], GetDynamicSize(i));
|
||||
}
|
||||
DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
|
||||
ShapeUtil::ByteSizeOf(shape()));
|
||||
std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
|
||||
@ -878,14 +747,6 @@ Literal LiteralBase::SliceInternal(
|
||||
return Get<NativeT>(new_indices);
|
||||
})
|
||||
.ok());
|
||||
for (int64 dnum = 0; dnum < shape().rank(); ++dnum) {
|
||||
if (shape().is_dynamic_dimension(dnum)) {
|
||||
int64 dynamic_size = GetDynamicSize(dnum) - start_indices[dnum];
|
||||
CHECK_GE(dynamic_size, 0) << GetDynamicSize(dnum);
|
||||
dynamic_size = std::min(dynamic_size, result_shape.dimensions(dnum));
|
||||
result_literal.SetDynamicSize(dnum, dynamic_size);
|
||||
}
|
||||
}
|
||||
return result_literal;
|
||||
}
|
||||
|
||||
@ -902,10 +763,9 @@ Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
|
||||
CHECK_GE(dimension, 0) << "dnum = " << dnum;
|
||||
result_dimensions.push_back(dimension);
|
||||
}
|
||||
auto result_shape =
|
||||
const auto result_shape =
|
||||
ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
|
||||
LayoutUtil::MinorToMajor(shape()));
|
||||
ShapeUtil::CopyDynamicDimensions(&result_shape, shape());
|
||||
switch (result_shape.element_type()) {
|
||||
case PRED:
|
||||
return SliceInternal<bool>(result_shape, start_indices);
|
||||
@ -1222,24 +1082,11 @@ void DenseArrayToStringHelper(const LiteralBase& literal,
|
||||
|
||||
if (print_shape) {
|
||||
pieces->push_back(ShapeToString(print_layout, subshape));
|
||||
if (subshape.is_dynamic()) {
|
||||
pieces->push_back("(");
|
||||
for (int64 i = 0; i < subshape.dimensions_size(); ++i) {
|
||||
pieces->push_back(StrCat(literal.GetDynamicSize(i, shape_index)));
|
||||
if (i < subshape.dimensions_size() - 1) {
|
||||
pieces->push_back(",");
|
||||
}
|
||||
}
|
||||
pieces->push_back(")");
|
||||
}
|
||||
pieces->push_back(" ");
|
||||
}
|
||||
std::vector<int64> indices = {};
|
||||
std::vector<int64> dimensions;
|
||||
dimensions.reserve(subshape.rank());
|
||||
for (int64 i = 0; i < subshape.rank(); ++i) {
|
||||
dimensions.push_back(literal.GetDynamicSize(i, shape_index));
|
||||
}
|
||||
std::vector<int64> dimensions(subshape.dimensions().begin(),
|
||||
subshape.dimensions().end());
|
||||
to_string_recursive(dimensions, &indices);
|
||||
}
|
||||
|
||||
@ -1527,44 +1374,13 @@ StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const {
|
||||
return literal;
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
void LiteralBase::Piece::CopyElementsWithDynamicBound(
|
||||
const LiteralBase::Piece& src) {
|
||||
auto dest_shape = subshape();
|
||||
auto src_shape = src.subshape();
|
||||
|
||||
// At least one shape has to be static as bound.
|
||||
CHECK(dest_shape.is_static() || src_shape.is_static());
|
||||
auto bound_shape = dest_shape.is_static() ? src_shape : dest_shape;
|
||||
if (ShapeUtil::IsZeroElementArray(dest_shape)) {
|
||||
return;
|
||||
}
|
||||
std::vector<int64> index(dest_shape.rank());
|
||||
do {
|
||||
bool out_of_bound = false;
|
||||
for (int64 i = 0; i < index.size(); ++i) {
|
||||
// Do not copy elements beyond dynamic bound.
|
||||
if (index[i] >= GetDynamicSize(i) || index[i] >= src.GetDynamicSize(i)) {
|
||||
out_of_bound = true;
|
||||
}
|
||||
}
|
||||
if (out_of_bound) {
|
||||
continue;
|
||||
}
|
||||
data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape,
|
||||
index)] =
|
||||
src.data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
|
||||
src_shape, index)];
|
||||
} while (IndexUtil::BumpIndices(bound_shape, absl::MakeSpan(index)));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
bool LiteralBase::Piece::EqualElementsInternal(
|
||||
const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
|
||||
if (multi_index->size() == subshape().rank()) {
|
||||
return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
|
||||
}
|
||||
for (int64 i = 0; i < GetDynamicSize(multi_index->size()); ++i) {
|
||||
for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) {
|
||||
multi_index->push_back(i);
|
||||
if (!EqualElementsInternal<NativeT>(other, multi_index)) {
|
||||
return false;
|
||||
@ -1574,26 +1390,10 @@ bool LiteralBase::Piece::EqualElementsInternal(
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LiteralBase::Piece::EqualDynamicSize(
|
||||
const LiteralBase::Piece& other) const {
|
||||
DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
|
||||
if (subshape().is_static()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
for (int64 i = 0; i < subshape().rank(); ++i) {
|
||||
if (GetDynamicSize(i) != other.GetDynamicSize(i)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
|
||||
DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
|
||||
|
||||
if (subshape().is_static() &&
|
||||
ShapeUtil::Equal(subshape(), other.subshape()) &&
|
||||
if (ShapeUtil::Equal(subshape(), other.subshape()) &&
|
||||
LayoutUtil::IsDenseArray(subshape())) {
|
||||
CHECK_EQ(size_bytes(), other.size_bytes());
|
||||
return memcmp(buffer(), other.buffer(), size_bytes()) == 0;
|
||||
@ -1636,33 +1436,17 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
|
||||
}
|
||||
|
||||
bool LiteralBase::operator==(const LiteralBase& other) const {
|
||||
// Checking the structure of tuple literals. Checks for dense arrays are
|
||||
// performed below.
|
||||
if (!ShapeUtil::EqualStructure(shape(), other.shape())) {
|
||||
if (!ShapeUtil::Compatible(shape(), other.shape())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return root_piece().ForEachSubpieceWithBool(
|
||||
[&](const ShapeIndex& index, const Piece& piece) {
|
||||
const Piece& other_piece = other.piece(index);
|
||||
const Shape& subshape = piece.subshape();
|
||||
const Shape& other_subshape = other_piece.subshape();
|
||||
if (subshape.element_type() != other_subshape.element_type()) {
|
||||
return false;
|
||||
}
|
||||
if (!piece.subshape().IsArray()) {
|
||||
return true;
|
||||
}
|
||||
if (subshape.rank() != other_subshape.rank()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int64 i = 0; i < subshape.rank(); ++i) {
|
||||
if (piece.GetDynamicSize(i) != other_piece.GetDynamicSize(i)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const Piece& other_piece = other.piece(index);
|
||||
if (!piece.EqualElements(other_piece)) {
|
||||
return false;
|
||||
}
|
||||
@ -2251,7 +2035,6 @@ void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape,
|
||||
}
|
||||
} else if (shape.IsArray()) {
|
||||
dest_piece->set_buffer(src_piece->buffer());
|
||||
dest_piece->set_dynamic_size_buffer(src_piece->dynamic_size_buffer());
|
||||
} else {
|
||||
// If the shape is neither an array nor tuple, then it must be
|
||||
// zero-sized. Otherwise, some memory needs to be allocated for it.
|
||||
|
@ -112,10 +112,6 @@ class LiteralBase {
|
||||
template <typename NativeT>
|
||||
NativeT Get(absl::Span<const int64> multi_index) const;
|
||||
|
||||
// Get the dynamic size on dim_index in the literal at the given shape_index.
|
||||
int32 GetDynamicSize(int64 dim_index, const ShapeIndex& shape_index) const;
|
||||
int32 GetDynamicSize(int64 dim_index) const;
|
||||
|
||||
// Returns the element value at index (0, ..., 0), however many zeroes are
|
||||
// required for that index.
|
||||
template <typename NativeT>
|
||||
@ -285,18 +281,6 @@ class LiteralBase {
|
||||
// than being limited to a single array within the shape.
|
||||
Literal Relayout(const Shape& shape_with_layout) const;
|
||||
|
||||
// Generate a new literal whose static sizes are equal to the previous
|
||||
// literal's dynamic sizes.
|
||||
Literal ToStatic() const;
|
||||
|
||||
// Expand a static literal into a new one with a bounded dyanmic literal. The
|
||||
// static dimensions of the original literal becomes dynamic dimensions of the
|
||||
// new literal, where the argument `bounded_shape` becomes the bounded shape
|
||||
// of the new literal.
|
||||
//
|
||||
// Precondition: bounded_shape.is_dynamic()
|
||||
Literal ToBoundedDynamic(const Shape& bounded_shape) const;
|
||||
|
||||
// Creates a new literal by reshaping this literal to have the given
|
||||
// dimensions. The total number of elements must not change; The
|
||||
// implementation currently only supports monotonic dim0-major layouts.
|
||||
@ -370,22 +354,10 @@ class LiteralBase {
|
||||
template <typename NativeT>
|
||||
void Set(absl::Span<const int64> index, NativeT value);
|
||||
|
||||
int32 GetDynamicSize(int64 dim_index) const;
|
||||
void SetDynamicSize(int64 dim_index, int32 size);
|
||||
// Gets/sets the buffer holding the array data.
|
||||
char* buffer() const { return buffer_; }
|
||||
void set_buffer(char* buffer) { buffer_ = buffer; }
|
||||
|
||||
// Gets/sets the buffer holding dynamic sizes.
|
||||
int32* dynamic_size_buffer() const { return dynamic_size_buffer_; }
|
||||
void set_dynamic_size_buffer(int32* dynamic_size_buffer) {
|
||||
dynamic_size_buffer_ = dynamic_size_buffer;
|
||||
}
|
||||
|
||||
int64 dynamic_size_buffer_bytes() const {
|
||||
return subshape().dimensions_size() * sizeof(int32);
|
||||
}
|
||||
|
||||
// Gets or sets the subshape of this piece. This reference points to a
|
||||
// subshape within the shape in the containing Literal (Literal::shape_).
|
||||
const Shape& subshape() const { return *subshape_; }
|
||||
@ -462,21 +434,15 @@ class LiteralBase {
|
||||
}
|
||||
|
||||
// Returns true if this piece and 'other' contain the same data. This piece
|
||||
// and 'other' must be array-shaped and compatible. If a literal has dynamic
|
||||
// shape, comparison is done only for the valid elements.
|
||||
// and 'other' must be array-shaped and compatible.
|
||||
bool EqualElements(const Piece& other) const;
|
||||
|
||||
// Returns true if this piece and other pieces have the same dynamic
|
||||
// dimension sizes.
|
||||
bool EqualDynamicSize(const Piece& other) const;
|
||||
|
||||
// Writes the shape and data (if array-shaped) into the given proto.
|
||||
void WriteToProto(LiteralProto* proto) const;
|
||||
|
||||
// Copy the data from 'src' into this piece's buffer. Shapes of this piece
|
||||
// and src must be compatible. If only_dynamic_bound is true, only elements
|
||||
// within dynamic bounds will be copied.
|
||||
Status CopyFrom(const Piece& src, bool only_dynamic_bound);
|
||||
// and src must be compatible.
|
||||
Status CopyFrom(const Piece& src);
|
||||
|
||||
// Copies the data from the given proto into this piece. The shape of this
|
||||
// piece must be equal (not just compatible) to the shape of the proto.
|
||||
@ -531,15 +497,9 @@ class LiteralBase {
|
||||
bool EqualElementsInternal(const Piece& other,
|
||||
std::vector<int64>* multi_index) const;
|
||||
|
||||
// Internal helper to copy elements from another given piece
|
||||
template <typename NativeT>
|
||||
void CopyElementsWithDynamicBound(const LiteralBase::Piece& src);
|
||||
|
||||
// For array-shaped pieces, this is the buffer holding the literal data.
|
||||
char* buffer_ = nullptr;
|
||||
|
||||
int32* dynamic_size_buffer_ = nullptr;
|
||||
|
||||
// The shape of piece. This points into the shape of the containing Literal
|
||||
// (Literal::shape_).
|
||||
const Shape* subshape_ = nullptr;
|
||||
@ -590,11 +550,6 @@ class MutableLiteralBase : public LiteralBase {
|
||||
// mutate the shape as this can produce malformed Literals.
|
||||
Shape* mutable_shape_do_not_use() { return shape_.get(); }
|
||||
|
||||
// Set the dynamic size on dim_index in the literal at the given shape_index.
|
||||
void SetDynamicSize(int64 dim_index, const ShapeIndex& shape_index,
|
||||
int32 size);
|
||||
void SetDynamicSize(int64 dim_index, int32 size);
|
||||
|
||||
// Returns a pointer to the underlying buffer holding the array at the given
|
||||
// shape index. CHECKs if the subshape of the literal at the given ShapeIndex
|
||||
// is not array.
|
||||
@ -605,12 +560,10 @@ class MutableLiteralBase : public LiteralBase {
|
||||
// Copy values from 'src_literal' rooted at 'src_shape_index' into this
|
||||
// literal rooted at 'dest_shape_index'. The subshape of this literal rooted
|
||||
// at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
|
||||
// rooted at 'src_shape_index', but need not be arrays. If only_dynamic_bound
|
||||
// is true, only elements within dynamic bounds will be copied.
|
||||
// rooted at 'src_shape_index', but need not be arrays.
|
||||
Status CopyFrom(const LiteralSlice& src_literal,
|
||||
const ShapeIndex& dest_shape_index = {},
|
||||
const ShapeIndex& src_shape_index = {},
|
||||
bool only_dynamic_bound = false);
|
||||
const ShapeIndex& src_shape_index = {});
|
||||
|
||||
// Copies the values from src_literal, starting at src_base shape indexes,
|
||||
// to this literal, starting at dest_base, where the copy size in each
|
||||
@ -971,14 +924,9 @@ void LiteralBase::EachCell(
|
||||
return;
|
||||
}
|
||||
std::vector<int64> indices(shape().rank(), 0);
|
||||
|
||||
Shape shape_dynamic = shape();
|
||||
for (int64 i = 0; i < shape_dynamic.rank(); ++i) {
|
||||
shape_dynamic.set_dimensions(i, GetDynamicSize(i));
|
||||
}
|
||||
do {
|
||||
per_cell(indices, Get<NativeT>(indices));
|
||||
} while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices)));
|
||||
} while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
|
@ -149,16 +149,6 @@ TEST_F(LiteralUtilTest, R2ToString) {
|
||||
EXPECT_EQ(expected, literal.ToString());
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, R2DynamicToString) {
|
||||
auto literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}});
|
||||
literal.SetDynamicSize(0, {}, 2);
|
||||
const string expected = R"(s32[<=3,2](2,2) {
|
||||
{ 1, 2 },
|
||||
{ 3, 4 }
|
||||
})";
|
||||
EXPECT_EQ(expected, literal.ToString());
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, R3ToString) {
|
||||
const auto literal =
|
||||
LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}});
|
||||
@ -431,28 +421,6 @@ TEST_F(LiteralUtilTest, TupleEquality) {
|
||||
EXPECT_NE(tuple1, different_tuple);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, DynamicShapeEquality) {
|
||||
// Test equality with tuples.
|
||||
auto r1 = LiteralUtil::CreateR1<float>({1.0, 2.0});
|
||||
r1.SetDynamicSize(0, {}, 1);
|
||||
auto r2 = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
||||
r2.SetDynamicSize(0, {}, 1);
|
||||
auto tuple1 = LiteralUtil::MakeTuple({&r1, &r2});
|
||||
|
||||
// Tuple with the same elements. One element is shared with the original
|
||||
// tuple, the other is a clone of the element in the original tuple.
|
||||
auto r1_clone = LiteralUtil::CreateR1<float>({1.0, 3.0});
|
||||
r1_clone.SetDynamicSize(0, {}, 1);
|
||||
auto tuple2 = LiteralUtil::MakeTuple({&r1_clone, &r2});
|
||||
EXPECT_EQ(tuple1, tuple2);
|
||||
|
||||
// Tuple with different dynamic sizes.
|
||||
auto r2_clone = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
||||
r2_clone.SetDynamicSize(0, {}, 2);
|
||||
auto tuple_3 = LiteralUtil::MakeTuple({&r1_clone, &r2_clone});
|
||||
EXPECT_NE(tuple1, tuple_3);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, C64Equality) {
|
||||
// Test equality with tuples.
|
||||
auto vector = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
|
||||
@ -724,47 +692,6 @@ TEST_F(LiteralUtilTest, TransposeR4) {
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, TransposeDynamicR2) {
|
||||
// F32[2, <=3] (2, 1)
|
||||
auto original = LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}});
|
||||
original.SetDynamicSize(1, 1);
|
||||
// F32[<=3, 2] (1, 2)
|
||||
auto reshape = original.Transpose(/*permutation=*/{1, 0});
|
||||
|
||||
reshape.EachCell<float>([&](absl::Span<const int64> indices, float value) {
|
||||
EXPECT_EQ(value, original.Get<float>({indices[1], indices[0]}));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, ToStaticR2) {
|
||||
// F32[2, <=3] (2, 1)
|
||||
auto original = LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}});
|
||||
original.SetDynamicSize(1, 1);
|
||||
// F32[2, 1]
|
||||
auto static_literal = original.ToStatic();
|
||||
EXPECT_EQ(static_literal.shape(), ShapeUtil::MakeShape(F32, {2, 1}));
|
||||
EXPECT_TRUE(static_literal.shape().is_static());
|
||||
|
||||
static_literal.EachCell<float>(
|
||||
[&](absl::Span<const int64> indices, float value) {
|
||||
EXPECT_EQ(value, original.Get<float>({indices[0], indices[1]}));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, ToBoundedDynamicR2) {
|
||||
// F32[2, 1]
|
||||
auto original = LiteralUtil::CreateR2<float>({{1}, {4}});
|
||||
// F32[2, <=3] (2, 1)
|
||||
auto dynamic_shape = ShapeUtil::MakeShape(F32, {2, 3}, {false, true});
|
||||
auto dynamic_literal = original.ToBoundedDynamic(dynamic_shape);
|
||||
EXPECT_EQ(dynamic_literal.shape(), dynamic_shape);
|
||||
|
||||
dynamic_literal.EachCell<float>(
|
||||
[&](absl::Span<const int64> indices, float value) {
|
||||
EXPECT_EQ(value, original.Get<float>({indices[0], indices[1]}));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
|
||||
// Tests that using Relayout on an array is equivalent to creating it in the
|
||||
// target layout in the first place.
|
||||
@ -870,38 +797,6 @@ TEST_F(LiteralUtilTest, SliceR3U32Full) {
|
||||
EXPECT_EQ(input_2x3x2, result);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, SliceR2Dynamic) {
|
||||
auto input_3x4 = LiteralUtil::CreateR2<uint32>(
|
||||
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
|
||||
input_3x4.SetDynamicSize(1, 3);
|
||||
// slice second dim from dynamic size 3 to dynamic size 1.
|
||||
auto result = input_3x4.Slice({0, 1}, {2, 2});
|
||||
auto expected = LiteralUtil::CreateR2<uint32>({{2}, {6}});
|
||||
EXPECT_EQ(expected, result);
|
||||
EXPECT_EQ(result.GetDynamicSize(1), 1);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, SliceR2DynamicInBound) {
|
||||
auto input_3x4 = LiteralUtil::CreateR2<uint32>(
|
||||
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
|
||||
input_3x4.SetDynamicSize(1, 1);
|
||||
auto result = input_3x4.Slice({0, 0}, {2, 2});
|
||||
auto expected = LiteralUtil::CreateR2<uint32>({{1}, {5}});
|
||||
EXPECT_EQ(expected, result);
|
||||
EXPECT_EQ(result.GetDynamicSize(1), 1);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, SliceR2DynamicOutOfBound) {
|
||||
auto input_3x4 = LiteralUtil::CreateR2<uint32>(
|
||||
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
|
||||
input_3x4.SetDynamicSize(1, 1);
|
||||
auto result = input_3x4.Slice({0, 1}, {2, 3});
|
||||
auto expected = LiteralUtil::CreateR2<uint32>({{}, {}});
|
||||
EXPECT_EQ(expected, result);
|
||||
// Out of bound access clamps into 0 sized dimension.
|
||||
EXPECT_EQ(result.GetDynamicSize(1), 0);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, PopulateR1S64) {
|
||||
Literal output(ShapeUtil::MakeShape(S64, {1}));
|
||||
output.PopulateR1<int64>({77});
|
||||
@ -1615,7 +1510,7 @@ TEST_F(LiteralUtilTest, CopyFromProto_u16) {
|
||||
EXPECT_EQ(u1, r[3]);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, LiteralDynamicSliceTest) {
|
||||
TEST_F(LiteralUtilTest, LiteralSliceTest) {
|
||||
auto scalar = LiteralUtil::CreateR0<float>(1.0);
|
||||
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
||||
auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
|
||||
@ -2078,17 +1973,6 @@ TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) {
|
||||
LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}}));
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, DynamicBroadcast) {
|
||||
Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
|
||||
literal.SetDynamicSize(0, 1);
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
Literal broadcasted_literal,
|
||||
literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
|
||||
/*dimensions=*/{1}));
|
||||
EXPECT_EQ(broadcasted_literal, LiteralUtil::CreateR2<int64>({{1}, {1}}));
|
||||
EXPECT_EQ(broadcasted_literal.GetDynamicSize(1), 1);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, GetAsComplex128) {
|
||||
complex128 value = {1, 0};
|
||||
Literal c1 = LiteralUtil::CreateR0<complex128>(value);
|
||||
|
@ -440,10 +440,6 @@ Status HloEvaluator::HandleSetDimensionSize(
|
||||
Literal result(set_dimension_size->shape());
|
||||
memcpy(result.untyped_data(), operand_literal.untyped_data(),
|
||||
operand_literal.size_bytes());
|
||||
const Literal& size_literal =
|
||||
GetEvaluatedLiteralFor(set_dimension_size->operand(1));
|
||||
result.SetDynamicSize(set_dimension_size->dimension(),
|
||||
size_literal.Get<int32>({}));
|
||||
evaluated_[set_dimension_size] = std::move(result);
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -81,17 +81,8 @@ StatusOr<ExecutionOutput> InterpreterExecutableBase::ExecuteAsyncOnStream(
|
||||
for (int64 i = 0; i < computation->num_parameters(); ++i) {
|
||||
const auto& expected_shape = computation->parameter_instruction(i)->shape();
|
||||
const auto& actual_shape = argument_buffers[i].on_device_shape();
|
||||
bool shape_match = true;
|
||||
if (expected_shape.is_dynamic()) {
|
||||
if (!ShapeUtil::DynamicArrayShapeIsCompatible(actual_shape,
|
||||
expected_shape)) {
|
||||
shape_match = false;
|
||||
}
|
||||
} else if (!Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape,
|
||||
if (!Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape,
|
||||
actual_shape)) {
|
||||
shape_match = false;
|
||||
}
|
||||
if (!shape_match) {
|
||||
return InvalidArgument(
|
||||
"Shape mismatch on parameter %d. Expected %s, but was %s.", i,
|
||||
ShapeUtil::HumanStringWithLayout(expected_shape),
|
||||
@ -109,18 +100,11 @@ StatusOr<ExecutionOutput> InterpreterExecutableBase::ExecuteAsyncOnStream(
|
||||
TF_ASSIGN_OR_RETURN(Literal arg_literal,
|
||||
transfer_manager->TransferLiteralFromDevice(
|
||||
run_options->stream(), argument_buffers[p]));
|
||||
const auto& expected_shape = computation->parameter_instruction(p)->shape();
|
||||
if (expected_shape.is_dynamic()) {
|
||||
// Expand the input literal to expected shape.
|
||||
arg_literal = arg_literal.ToBoundedDynamic(expected_shape);
|
||||
}
|
||||
arg_literals.push_back(std::move(arg_literal));
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(Literal result_literal,
|
||||
Evaluate(*computation, arg_literals));
|
||||
// Shrink the generated dynamic shape into static shape.
|
||||
result_literal = result_literal.ToStatic();
|
||||
|
||||
// Transform the result literal back into a ShapedBuffer.
|
||||
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result_buffers,
|
||||
|
@ -339,15 +339,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
TF_DCHECK_OK(ValidateShape(*shape));
|
||||
}
|
||||
|
||||
/* static */ void ShapeUtil::CopyDynamicDimensions(Shape* to,
|
||||
const Shape& from) {
|
||||
CHECK_EQ(to->rank(), from.rank());
|
||||
for (int64 i = 0; i < from.rank(); ++i) {
|
||||
to->set_dynamic_dimension(i, from.is_dynamic_dimension(i));
|
||||
}
|
||||
TF_DCHECK_OK(ValidateShape(*to));
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::ElementIsIntegral(const Shape& shape) {
|
||||
return primitive_util::IsIntegralType(shape.element_type());
|
||||
}
|
||||
|
@ -377,9 +377,6 @@ class ShapeUtil {
|
||||
// Appends a major dimension to the shape with the given bound.
|
||||
static void AppendMajorDimension(int bound, Shape* shape);
|
||||
|
||||
// Copy the dynamic dimensions property from one shape to another.
|
||||
static void CopyDynamicDimensions(Shape* to, const Shape& from);
|
||||
|
||||
// Returns an empty tuple shape. Can be used as a sentinel Shape value.
|
||||
static Shape MakeNil() { return MakeTupleShape({}); }
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user