Dynamic literal support
PiperOrigin-RevId: 321836977 Change-Id: Ib5524846de424da20643c6982f6454614d9ffa07
This commit is contained in:
parent
c49b3f570d
commit
ce190ec244
@ -48,10 +48,6 @@ namespace {
|
|||||||
using absl::StrCat;
|
using absl::StrCat;
|
||||||
|
|
||||||
constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
|
constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
|
||||||
// Literals can be used as DMA targets, which can require alignment. We
|
|
||||||
// force a tensorflow::Allocator::kAllocatorAlignment-byte minimum
|
|
||||||
// alignment.
|
|
||||||
constexpr int kMinimumAlignment = 64;
|
|
||||||
|
|
||||||
// Converts between little and big endian.
|
// Converts between little and big endian.
|
||||||
//
|
//
|
||||||
@ -137,14 +133,12 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
|
|||||||
}
|
}
|
||||||
} else if (shape.IsArray()) {
|
} else if (shape.IsArray()) {
|
||||||
if (allocate_arrays) {
|
if (allocate_arrays) {
|
||||||
|
// Literals can be used as DMA targets, which can require alignment. We
|
||||||
|
// force a tensorflow::Allocator::kAllocatorAlignment-byte minimum
|
||||||
|
// alignment.
|
||||||
|
constexpr int kMinimumAlignment = 64;
|
||||||
piece->set_buffer(static_cast<char*>(tensorflow::port::AlignedMalloc(
|
piece->set_buffer(static_cast<char*>(tensorflow::port::AlignedMalloc(
|
||||||
piece->size_bytes(), kMinimumAlignment)));
|
piece->size_bytes(), kMinimumAlignment)));
|
||||||
if (shape.is_dynamic()) {
|
|
||||||
CHECK_EQ(piece->dynamic_size_buffer(), nullptr);
|
|
||||||
piece->set_dynamic_size_buffer(
|
|
||||||
static_cast<int32*>(tensorflow::port::AlignedMalloc(
|
|
||||||
piece->dynamic_size_buffer_bytes(), kMinimumAlignment)));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// If the shape is neither an array nor tuple, then it must be
|
// If the shape is neither an array nor tuple, then it must be
|
||||||
@ -177,9 +171,6 @@ void Literal::DeallocateBuffers() {
|
|||||||
if (piece->buffer() != nullptr) {
|
if (piece->buffer() != nullptr) {
|
||||||
tensorflow::port::AlignedFree(piece->buffer());
|
tensorflow::port::AlignedFree(piece->buffer());
|
||||||
}
|
}
|
||||||
if (piece->dynamic_size_buffer() != nullptr) {
|
|
||||||
tensorflow::port::AlignedFree(piece->dynamic_size_buffer());
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -208,15 +199,6 @@ Literal LiteralBase::CreateFromShape(const Shape& shape) {
|
|||||||
return literal;
|
return literal;
|
||||||
}
|
}
|
||||||
|
|
||||||
int32 LiteralBase::GetDynamicSize(int64 dim_index) const {
|
|
||||||
return GetDynamicSize(dim_index, {});
|
|
||||||
}
|
|
||||||
|
|
||||||
int32 LiteralBase::GetDynamicSize(int64 dim_index,
|
|
||||||
const ShapeIndex& shape_index) const {
|
|
||||||
return piece(shape_index).GetDynamicSize(dim_index);
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::optional<int64> LiteralBase::GetFirstInteger() const {
|
absl::optional<int64> LiteralBase::GetFirstInteger() const {
|
||||||
switch (shape().element_type()) {
|
switch (shape().element_type()) {
|
||||||
case U8:
|
case U8:
|
||||||
@ -399,9 +381,7 @@ std::vector<Literal> Literal::DecomposeTuple() {
|
|||||||
|
|
||||||
// Move the respective buffer over to the element Literal.
|
// Move the respective buffer over to the element Literal.
|
||||||
dest_piece->set_buffer(src_piece.buffer());
|
dest_piece->set_buffer(src_piece.buffer());
|
||||||
dest_piece->set_dynamic_size_buffer(src_piece.dynamic_size_buffer());
|
|
||||||
src_piece.set_buffer(nullptr);
|
src_piece.set_buffer(nullptr);
|
||||||
src_piece.set_dynamic_size_buffer(nullptr);
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
// Set this literal to be nil-shaped.
|
// Set this literal to be nil-shaped.
|
||||||
@ -427,51 +407,23 @@ void CopyElementsBetween(absl::Span<NativeT> dest,
|
|||||||
src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
|
src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
|
||||||
} while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index)));
|
} while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index)));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
int32 LiteralBase::Piece::GetDynamicSize(int64 dim_index) const {
|
Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
|
||||||
CHECK(LayoutUtil::IsDenseArray(subshape()));
|
|
||||||
if (!subshape_->is_dynamic_dimension(dim_index)) {
|
|
||||||
// This is a static dimension, return size.
|
|
||||||
return subshape_->dimensions(dim_index);
|
|
||||||
}
|
|
||||||
CHECK_NE(dynamic_size_buffer(), nullptr);
|
|
||||||
return dynamic_size_buffer_[dim_index];
|
|
||||||
}
|
|
||||||
|
|
||||||
void LiteralBase::Piece::SetDynamicSize(int64 dim_index, int32 size) {
|
|
||||||
CHECK(LayoutUtil::IsDenseArray(subshape()));
|
|
||||||
CHECK(subshape_->is_dynamic_dimension(dim_index));
|
|
||||||
if (dynamic_size_buffer() == nullptr) {
|
|
||||||
// Lazily initialize the dynamic size buffer.
|
|
||||||
set_dynamic_size_buffer(static_cast<int32*>(tensorflow::port::AlignedMalloc(
|
|
||||||
dynamic_size_buffer_bytes(), kMinimumAlignment)));
|
|
||||||
/*for (int64 i = 0; i < subshape().rank(); ++i) {
|
|
||||||
// Initialized to -1 to help debug.
|
|
||||||
dynamic_size_buffer_[i] = -1;
|
|
||||||
}*/
|
|
||||||
}
|
|
||||||
dynamic_size_buffer_[dim_index] = size;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src,
|
|
||||||
bool only_dynamic_bound) {
|
|
||||||
CHECK(subshape_ != nullptr);
|
CHECK(subshape_ != nullptr);
|
||||||
CHECK(src.subshape_ != nullptr);
|
CHECK(src.subshape_ != nullptr);
|
||||||
if (ShapeUtil::Equal(subshape(), src.subshape())) {
|
if (ShapeUtil::Equal(subshape(), src.subshape())) {
|
||||||
// If the layouts are equal it's faster just to memcpy.
|
// If the layouts are equal it's faster just to memcpy.
|
||||||
memcpy(buffer(), src.buffer(), src.size_bytes());
|
memcpy(buffer(), src.buffer(), src.size_bytes());
|
||||||
} else {
|
} else {
|
||||||
|
TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape()));
|
||||||
std::vector<int64> origin(subshape().rank(), 0);
|
std::vector<int64> origin(subshape().rank(), 0);
|
||||||
switch (subshape().element_type()) {
|
switch (subshape().element_type()) {
|
||||||
#define COPY_ELEMENTS(XLA_T, NATIVE_T) \
|
#define COPY_ELEMENTS(XLA_T, NATIVE_T) \
|
||||||
case (XLA_T): \
|
case (XLA_T): \
|
||||||
if (only_dynamic_bound) { \
|
CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
|
||||||
CopyElementsWithDynamicBound<NATIVE_T>(src); \
|
subshape(), src.subshape()); \
|
||||||
} else { \
|
|
||||||
CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
|
|
||||||
subshape(), src.subshape()); \
|
|
||||||
} \
|
|
||||||
break;
|
break;
|
||||||
COPY_ELEMENTS(U8, uint8);
|
COPY_ELEMENTS(U8, uint8);
|
||||||
COPY_ELEMENTS(U16, uint16);
|
COPY_ELEMENTS(U16, uint16);
|
||||||
@ -495,54 +447,21 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src,
|
|||||||
PrimitiveType_Name(subshape().element_type()));
|
PrimitiveType_Name(subshape().element_type()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
DCHECK_EQ(dynamic_size_buffer_bytes(), src.dynamic_size_buffer_bytes());
|
|
||||||
if (subshape().is_dynamic() && src.subshape().is_dynamic()) {
|
|
||||||
CHECK_NE(dynamic_size_buffer_, nullptr);
|
|
||||||
CHECK_NE(src.dynamic_size_buffer_, nullptr);
|
|
||||||
memcpy(dynamic_size_buffer(), src.dynamic_size_buffer(),
|
|
||||||
src.dynamic_size_buffer_bytes());
|
|
||||||
}
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
void MutableLiteralBase::SetDynamicSize(int64 dim_index, int32 size) {
|
|
||||||
return SetDynamicSize(dim_index, {}, size);
|
|
||||||
}
|
|
||||||
|
|
||||||
void MutableLiteralBase::SetDynamicSize(int64 dim_index,
|
|
||||||
const ShapeIndex& shape_index,
|
|
||||||
int32 size) {
|
|
||||||
Shape* subshape_ = ShapeUtil::GetMutableSubshape(shape_.get(), shape_index);
|
|
||||||
CHECK_GE(subshape_->dimensions(dim_index), size);
|
|
||||||
if (subshape_->dimensions(dim_index) == size) {
|
|
||||||
subshape_->set_dynamic_dimension(dim_index, false);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
subshape_->set_dynamic_dimension(dim_index, true);
|
|
||||||
piece(shape_index).SetDynamicSize(dim_index, size);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
|
Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
|
||||||
const ShapeIndex& dest_shape_index,
|
const ShapeIndex& dest_shape_index,
|
||||||
const ShapeIndex& src_shape_index,
|
const ShapeIndex& src_shape_index) {
|
||||||
bool only_dynamic_bound) {
|
|
||||||
const Shape& dest_subshape =
|
const Shape& dest_subshape =
|
||||||
ShapeUtil::GetSubshape(shape(), dest_shape_index);
|
ShapeUtil::GetSubshape(shape(), dest_shape_index);
|
||||||
const Shape& src_subshape =
|
const Shape& src_subshape =
|
||||||
ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
|
ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
|
||||||
if (only_dynamic_bound) {
|
if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
|
||||||
auto bound_shape = dest_subshape.is_static() ? src_subshape : dest_subshape;
|
return InvalidArgument(
|
||||||
auto compact_shape =
|
"Destination subshape incompatible with source subshape: %s vs %s",
|
||||||
dest_subshape.is_static() ? dest_subshape : src_subshape;
|
ShapeUtil::HumanString(dest_subshape),
|
||||||
CHECK(ShapeUtil::DynamicShapeIsCompatible(compact_shape, bound_shape))
|
ShapeUtil::HumanString(src_subshape));
|
||||||
<< compact_shape.ToString() << " vs " << bound_shape.ToString();
|
|
||||||
} else {
|
|
||||||
if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
|
|
||||||
return InvalidArgument(
|
|
||||||
"Destination subshape incompatible with source subshape: %s vs %s",
|
|
||||||
ShapeUtil::HumanString(dest_subshape),
|
|
||||||
ShapeUtil::HumanString(src_subshape));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return root_piece_->ForEachMutableSubpieceWithStatus(
|
return root_piece_->ForEachMutableSubpieceWithStatus(
|
||||||
[&](const ShapeIndex& index, Piece* piece) {
|
[&](const ShapeIndex& index, Piece* piece) {
|
||||||
@ -567,9 +486,7 @@ Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
|
|||||||
for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
|
for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
|
||||||
src_piece_index.push_back(index[i]);
|
src_piece_index.push_back(index[i]);
|
||||||
}
|
}
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index)));
|
||||||
piece->CopyFrom(src_literal.piece(src_piece_index),
|
|
||||||
/*only_dynamic_bound=*/only_dynamic_bound));
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -597,9 +514,7 @@ Status Literal::MoveFrom(Literal&& src_literal,
|
|||||||
}
|
}
|
||||||
Piece& dest_piece = piece(dest_index);
|
Piece& dest_piece = piece(dest_index);
|
||||||
tensorflow::port::AlignedFree(dest_piece.buffer());
|
tensorflow::port::AlignedFree(dest_piece.buffer());
|
||||||
tensorflow::port::AlignedFree(dest_piece.dynamic_size_buffer());
|
|
||||||
dest_piece.set_buffer(src_piece.buffer());
|
dest_piece.set_buffer(src_piece.buffer());
|
||||||
dest_piece.set_dynamic_size_buffer(src_piece.dynamic_size_buffer());
|
|
||||||
});
|
});
|
||||||
|
|
||||||
src_literal.shape_ = absl::make_unique<Shape>(ShapeUtil::MakeNil());
|
src_literal.shape_ = absl::make_unique<Shape>(ShapeUtil::MakeNil());
|
||||||
@ -714,41 +629,6 @@ Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
Literal LiteralBase::ToBoundedDynamic(const Shape& bounded_shape) const {
|
|
||||||
CHECK(bounded_shape.is_dynamic());
|
|
||||||
Literal result(bounded_shape);
|
|
||||||
ShapeUtil::ForEachSubshape(
|
|
||||||
shape(), [&](const Shape& subshape, const ShapeIndex& index) {
|
|
||||||
if (!subshape.IsArray()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
for (int64 i = 0; i < subshape.rank(); ++i) {
|
|
||||||
result.SetDynamicSize(i, subshape.dimensions(i));
|
|
||||||
}
|
|
||||||
});
|
|
||||||
TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true));
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
Literal LiteralBase::ToStatic() const {
|
|
||||||
// Create new shape with 'new_layout' set at the given shape index.
|
|
||||||
Shape new_shape = shape();
|
|
||||||
ShapeUtil::ForEachMutableSubshape(
|
|
||||||
&new_shape, [this](Shape* subshape, const ShapeIndex& index) {
|
|
||||||
if (!subshape->IsArray()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
for (int64 i = 0; i < subshape->rank(); ++i) {
|
|
||||||
subshape->set_dynamic_dimension(i, false);
|
|
||||||
subshape->set_dimensions(i, GetDynamicSize(i, index));
|
|
||||||
}
|
|
||||||
});
|
|
||||||
Literal result(new_shape);
|
|
||||||
TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true));
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
StatusOr<Literal> LiteralBase::Broadcast(
|
StatusOr<Literal> LiteralBase::Broadcast(
|
||||||
const Shape& result_shape, absl::Span<const int64> dimensions) const {
|
const Shape& result_shape, absl::Span<const int64> dimensions) const {
|
||||||
if (!shape().IsArray()) {
|
if (!shape().IsArray()) {
|
||||||
@ -772,11 +652,6 @@ StatusOr<Literal> LiteralBase::Broadcast(
|
|||||||
const int64 primitive_size =
|
const int64 primitive_size =
|
||||||
ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
|
ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
|
||||||
|
|
||||||
for (int64 i = 0; i < dimensions.size(); ++i) {
|
|
||||||
int64 dynamic_size = GetDynamicSize(i);
|
|
||||||
result.SetDynamicSize(dimensions[i], dynamic_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
ShapeUtil::ForEachIndex(
|
ShapeUtil::ForEachIndex(
|
||||||
result_shape, [&](absl::Span<const int64> output_index) {
|
result_shape, [&](absl::Span<const int64> output_index) {
|
||||||
for (int64 i = 0; i < dimensions.size(); ++i) {
|
for (int64 i = 0; i < dimensions.size(); ++i) {
|
||||||
@ -799,9 +674,6 @@ StatusOr<Literal> LiteralBase::Reshape(
|
|||||||
if (!shape().IsArray()) {
|
if (!shape().IsArray()) {
|
||||||
return InvalidArgument("Reshape does not support tuples.");
|
return InvalidArgument("Reshape does not support tuples.");
|
||||||
}
|
}
|
||||||
if (shape().is_dynamic()) {
|
|
||||||
return Unimplemented("Dynamic reshape is not implemented.");
|
|
||||||
}
|
|
||||||
Literal output;
|
Literal output;
|
||||||
if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
|
if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
|
||||||
output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank()));
|
output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank()));
|
||||||
@ -856,9 +728,6 @@ Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
|
|||||||
layout->add_minor_to_major(inverse_permutation[index]);
|
layout->add_minor_to_major(inverse_permutation[index]);
|
||||||
}
|
}
|
||||||
Literal new_literal(permuted_shape);
|
Literal new_literal(permuted_shape);
|
||||||
for (int64 i = 0; i < shape().rank(); i++) {
|
|
||||||
new_literal.SetDynamicSize(inverse_permutation[i], GetDynamicSize(i));
|
|
||||||
}
|
|
||||||
DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
|
DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
|
||||||
ShapeUtil::ByteSizeOf(shape()));
|
ShapeUtil::ByteSizeOf(shape()));
|
||||||
std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
|
std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
|
||||||
@ -878,14 +747,6 @@ Literal LiteralBase::SliceInternal(
|
|||||||
return Get<NativeT>(new_indices);
|
return Get<NativeT>(new_indices);
|
||||||
})
|
})
|
||||||
.ok());
|
.ok());
|
||||||
for (int64 dnum = 0; dnum < shape().rank(); ++dnum) {
|
|
||||||
if (shape().is_dynamic_dimension(dnum)) {
|
|
||||||
int64 dynamic_size = GetDynamicSize(dnum) - start_indices[dnum];
|
|
||||||
CHECK_GE(dynamic_size, 0) << GetDynamicSize(dnum);
|
|
||||||
dynamic_size = std::min(dynamic_size, result_shape.dimensions(dnum));
|
|
||||||
result_literal.SetDynamicSize(dnum, dynamic_size);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result_literal;
|
return result_literal;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -902,10 +763,9 @@ Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
|
|||||||
CHECK_GE(dimension, 0) << "dnum = " << dnum;
|
CHECK_GE(dimension, 0) << "dnum = " << dnum;
|
||||||
result_dimensions.push_back(dimension);
|
result_dimensions.push_back(dimension);
|
||||||
}
|
}
|
||||||
auto result_shape =
|
const auto result_shape =
|
||||||
ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
|
ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
|
||||||
LayoutUtil::MinorToMajor(shape()));
|
LayoutUtil::MinorToMajor(shape()));
|
||||||
ShapeUtil::CopyDynamicDimensions(&result_shape, shape());
|
|
||||||
switch (result_shape.element_type()) {
|
switch (result_shape.element_type()) {
|
||||||
case PRED:
|
case PRED:
|
||||||
return SliceInternal<bool>(result_shape, start_indices);
|
return SliceInternal<bool>(result_shape, start_indices);
|
||||||
@ -1222,24 +1082,11 @@ void DenseArrayToStringHelper(const LiteralBase& literal,
|
|||||||
|
|
||||||
if (print_shape) {
|
if (print_shape) {
|
||||||
pieces->push_back(ShapeToString(print_layout, subshape));
|
pieces->push_back(ShapeToString(print_layout, subshape));
|
||||||
if (subshape.is_dynamic()) {
|
|
||||||
pieces->push_back("(");
|
|
||||||
for (int64 i = 0; i < subshape.dimensions_size(); ++i) {
|
|
||||||
pieces->push_back(StrCat(literal.GetDynamicSize(i, shape_index)));
|
|
||||||
if (i < subshape.dimensions_size() - 1) {
|
|
||||||
pieces->push_back(",");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pieces->push_back(")");
|
|
||||||
}
|
|
||||||
pieces->push_back(" ");
|
pieces->push_back(" ");
|
||||||
}
|
}
|
||||||
std::vector<int64> indices = {};
|
std::vector<int64> indices = {};
|
||||||
std::vector<int64> dimensions;
|
std::vector<int64> dimensions(subshape.dimensions().begin(),
|
||||||
dimensions.reserve(subshape.rank());
|
subshape.dimensions().end());
|
||||||
for (int64 i = 0; i < subshape.rank(); ++i) {
|
|
||||||
dimensions.push_back(literal.GetDynamicSize(i, shape_index));
|
|
||||||
}
|
|
||||||
to_string_recursive(dimensions, &indices);
|
to_string_recursive(dimensions, &indices);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1527,44 +1374,13 @@ StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const {
|
|||||||
return literal;
|
return literal;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
|
||||||
void LiteralBase::Piece::CopyElementsWithDynamicBound(
|
|
||||||
const LiteralBase::Piece& src) {
|
|
||||||
auto dest_shape = subshape();
|
|
||||||
auto src_shape = src.subshape();
|
|
||||||
|
|
||||||
// At least one shape has to be static as bound.
|
|
||||||
CHECK(dest_shape.is_static() || src_shape.is_static());
|
|
||||||
auto bound_shape = dest_shape.is_static() ? src_shape : dest_shape;
|
|
||||||
if (ShapeUtil::IsZeroElementArray(dest_shape)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
std::vector<int64> index(dest_shape.rank());
|
|
||||||
do {
|
|
||||||
bool out_of_bound = false;
|
|
||||||
for (int64 i = 0; i < index.size(); ++i) {
|
|
||||||
// Do not copy elements beyond dynamic bound.
|
|
||||||
if (index[i] >= GetDynamicSize(i) || index[i] >= src.GetDynamicSize(i)) {
|
|
||||||
out_of_bound = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (out_of_bound) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape,
|
|
||||||
index)] =
|
|
||||||
src.data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
|
|
||||||
src_shape, index)];
|
|
||||||
} while (IndexUtil::BumpIndices(bound_shape, absl::MakeSpan(index)));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
bool LiteralBase::Piece::EqualElementsInternal(
|
bool LiteralBase::Piece::EqualElementsInternal(
|
||||||
const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
|
const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
|
||||||
if (multi_index->size() == subshape().rank()) {
|
if (multi_index->size() == subshape().rank()) {
|
||||||
return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
|
return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
|
||||||
}
|
}
|
||||||
for (int64 i = 0; i < GetDynamicSize(multi_index->size()); ++i) {
|
for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) {
|
||||||
multi_index->push_back(i);
|
multi_index->push_back(i);
|
||||||
if (!EqualElementsInternal<NativeT>(other, multi_index)) {
|
if (!EqualElementsInternal<NativeT>(other, multi_index)) {
|
||||||
return false;
|
return false;
|
||||||
@ -1574,26 +1390,10 @@ bool LiteralBase::Piece::EqualElementsInternal(
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool LiteralBase::Piece::EqualDynamicSize(
|
|
||||||
const LiteralBase::Piece& other) const {
|
|
||||||
DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
|
|
||||||
if (subshape().is_static()) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int64 i = 0; i < subshape().rank(); ++i) {
|
|
||||||
if (GetDynamicSize(i) != other.GetDynamicSize(i)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
|
bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
|
||||||
DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
|
DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
|
||||||
|
|
||||||
if (subshape().is_static() &&
|
if (ShapeUtil::Equal(subshape(), other.subshape()) &&
|
||||||
ShapeUtil::Equal(subshape(), other.subshape()) &&
|
|
||||||
LayoutUtil::IsDenseArray(subshape())) {
|
LayoutUtil::IsDenseArray(subshape())) {
|
||||||
CHECK_EQ(size_bytes(), other.size_bytes());
|
CHECK_EQ(size_bytes(), other.size_bytes());
|
||||||
return memcmp(buffer(), other.buffer(), size_bytes()) == 0;
|
return memcmp(buffer(), other.buffer(), size_bytes()) == 0;
|
||||||
@ -1636,33 +1436,17 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool LiteralBase::operator==(const LiteralBase& other) const {
|
bool LiteralBase::operator==(const LiteralBase& other) const {
|
||||||
// Checking the structure of tuple literals. Checks for dense arrays are
|
if (!ShapeUtil::Compatible(shape(), other.shape())) {
|
||||||
// performed below.
|
|
||||||
if (!ShapeUtil::EqualStructure(shape(), other.shape())) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
return root_piece().ForEachSubpieceWithBool(
|
return root_piece().ForEachSubpieceWithBool(
|
||||||
[&](const ShapeIndex& index, const Piece& piece) {
|
[&](const ShapeIndex& index, const Piece& piece) {
|
||||||
const Piece& other_piece = other.piece(index);
|
|
||||||
const Shape& subshape = piece.subshape();
|
|
||||||
const Shape& other_subshape = other_piece.subshape();
|
|
||||||
if (subshape.element_type() != other_subshape.element_type()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (!piece.subshape().IsArray()) {
|
if (!piece.subshape().IsArray()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (subshape.rank() != other_subshape.rank()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int64 i = 0; i < subshape.rank(); ++i) {
|
|
||||||
if (piece.GetDynamicSize(i) != other_piece.GetDynamicSize(i)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
const Piece& other_piece = other.piece(index);
|
||||||
if (!piece.EqualElements(other_piece)) {
|
if (!piece.EqualElements(other_piece)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -2251,7 +2035,6 @@ void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape,
|
|||||||
}
|
}
|
||||||
} else if (shape.IsArray()) {
|
} else if (shape.IsArray()) {
|
||||||
dest_piece->set_buffer(src_piece->buffer());
|
dest_piece->set_buffer(src_piece->buffer());
|
||||||
dest_piece->set_dynamic_size_buffer(src_piece->dynamic_size_buffer());
|
|
||||||
} else {
|
} else {
|
||||||
// If the shape is neither an array nor tuple, then it must be
|
// If the shape is neither an array nor tuple, then it must be
|
||||||
// zero-sized. Otherwise, some memory needs to be allocated for it.
|
// zero-sized. Otherwise, some memory needs to be allocated for it.
|
||||||
|
@ -112,10 +112,6 @@ class LiteralBase {
|
|||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
NativeT Get(absl::Span<const int64> multi_index) const;
|
NativeT Get(absl::Span<const int64> multi_index) const;
|
||||||
|
|
||||||
// Get the dynamic size on dim_index in the literal at the given shape_index.
|
|
||||||
int32 GetDynamicSize(int64 dim_index, const ShapeIndex& shape_index) const;
|
|
||||||
int32 GetDynamicSize(int64 dim_index) const;
|
|
||||||
|
|
||||||
// Returns the element value at index (0, ..., 0), however many zeroes are
|
// Returns the element value at index (0, ..., 0), however many zeroes are
|
||||||
// required for that index.
|
// required for that index.
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
@ -285,18 +281,6 @@ class LiteralBase {
|
|||||||
// than being limited to a single array within the shape.
|
// than being limited to a single array within the shape.
|
||||||
Literal Relayout(const Shape& shape_with_layout) const;
|
Literal Relayout(const Shape& shape_with_layout) const;
|
||||||
|
|
||||||
// Generate a new literal whose static sizes are equal to the previous
|
|
||||||
// literal's dynamic sizes.
|
|
||||||
Literal ToStatic() const;
|
|
||||||
|
|
||||||
// Expand a static literal into a new one with a bounded dyanmic literal. The
|
|
||||||
// static dimensions of the original literal becomes dynamic dimensions of the
|
|
||||||
// new literal, where the argument `bounded_shape` becomes the bounded shape
|
|
||||||
// of the new literal.
|
|
||||||
//
|
|
||||||
// Precondition: bounded_shape.is_dynamic()
|
|
||||||
Literal ToBoundedDynamic(const Shape& bounded_shape) const;
|
|
||||||
|
|
||||||
// Creates a new literal by reshaping this literal to have the given
|
// Creates a new literal by reshaping this literal to have the given
|
||||||
// dimensions. The total number of elements must not change; The
|
// dimensions. The total number of elements must not change; The
|
||||||
// implementation currently only supports monotonic dim0-major layouts.
|
// implementation currently only supports monotonic dim0-major layouts.
|
||||||
@ -370,22 +354,10 @@ class LiteralBase {
|
|||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
void Set(absl::Span<const int64> index, NativeT value);
|
void Set(absl::Span<const int64> index, NativeT value);
|
||||||
|
|
||||||
int32 GetDynamicSize(int64 dim_index) const;
|
|
||||||
void SetDynamicSize(int64 dim_index, int32 size);
|
|
||||||
// Gets/sets the buffer holding the array data.
|
// Gets/sets the buffer holding the array data.
|
||||||
char* buffer() const { return buffer_; }
|
char* buffer() const { return buffer_; }
|
||||||
void set_buffer(char* buffer) { buffer_ = buffer; }
|
void set_buffer(char* buffer) { buffer_ = buffer; }
|
||||||
|
|
||||||
// Gets/sets the buffer holding dynamic sizes.
|
|
||||||
int32* dynamic_size_buffer() const { return dynamic_size_buffer_; }
|
|
||||||
void set_dynamic_size_buffer(int32* dynamic_size_buffer) {
|
|
||||||
dynamic_size_buffer_ = dynamic_size_buffer;
|
|
||||||
}
|
|
||||||
|
|
||||||
int64 dynamic_size_buffer_bytes() const {
|
|
||||||
return subshape().dimensions_size() * sizeof(int32);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Gets or sets the subshape of this piece. This reference points to a
|
// Gets or sets the subshape of this piece. This reference points to a
|
||||||
// subshape within the shape in the containing Literal (Literal::shape_).
|
// subshape within the shape in the containing Literal (Literal::shape_).
|
||||||
const Shape& subshape() const { return *subshape_; }
|
const Shape& subshape() const { return *subshape_; }
|
||||||
@ -462,21 +434,15 @@ class LiteralBase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Returns true if this piece and 'other' contain the same data. This piece
|
// Returns true if this piece and 'other' contain the same data. This piece
|
||||||
// and 'other' must be array-shaped and compatible. If a literal has dynamic
|
// and 'other' must be array-shaped and compatible.
|
||||||
// shape, comparison is done only for the valid elements.
|
|
||||||
bool EqualElements(const Piece& other) const;
|
bool EqualElements(const Piece& other) const;
|
||||||
|
|
||||||
// Returns true if this piece and other pieces have the same dynamic
|
|
||||||
// dimension sizes.
|
|
||||||
bool EqualDynamicSize(const Piece& other) const;
|
|
||||||
|
|
||||||
// Writes the shape and data (if array-shaped) into the given proto.
|
// Writes the shape and data (if array-shaped) into the given proto.
|
||||||
void WriteToProto(LiteralProto* proto) const;
|
void WriteToProto(LiteralProto* proto) const;
|
||||||
|
|
||||||
// Copy the data from 'src' into this piece's buffer. Shapes of this piece
|
// Copy the data from 'src' into this piece's buffer. Shapes of this piece
|
||||||
// and src must be compatible. If only_dynamic_bound is true, only elements
|
// and src must be compatible.
|
||||||
// within dynamic bounds will be copied.
|
Status CopyFrom(const Piece& src);
|
||||||
Status CopyFrom(const Piece& src, bool only_dynamic_bound);
|
|
||||||
|
|
||||||
// Copies the data from the given proto into this piece. The shape of this
|
// Copies the data from the given proto into this piece. The shape of this
|
||||||
// piece must be equal (not just compatible) to the shape of the proto.
|
// piece must be equal (not just compatible) to the shape of the proto.
|
||||||
@ -531,15 +497,9 @@ class LiteralBase {
|
|||||||
bool EqualElementsInternal(const Piece& other,
|
bool EqualElementsInternal(const Piece& other,
|
||||||
std::vector<int64>* multi_index) const;
|
std::vector<int64>* multi_index) const;
|
||||||
|
|
||||||
// Internal helper to copy elements from another given piece
|
|
||||||
template <typename NativeT>
|
|
||||||
void CopyElementsWithDynamicBound(const LiteralBase::Piece& src);
|
|
||||||
|
|
||||||
// For array-shaped pieces, this is the buffer holding the literal data.
|
// For array-shaped pieces, this is the buffer holding the literal data.
|
||||||
char* buffer_ = nullptr;
|
char* buffer_ = nullptr;
|
||||||
|
|
||||||
int32* dynamic_size_buffer_ = nullptr;
|
|
||||||
|
|
||||||
// The shape of piece. This points into the shape of the containing Literal
|
// The shape of piece. This points into the shape of the containing Literal
|
||||||
// (Literal::shape_).
|
// (Literal::shape_).
|
||||||
const Shape* subshape_ = nullptr;
|
const Shape* subshape_ = nullptr;
|
||||||
@ -590,11 +550,6 @@ class MutableLiteralBase : public LiteralBase {
|
|||||||
// mutate the shape as this can produce malformed Literals.
|
// mutate the shape as this can produce malformed Literals.
|
||||||
Shape* mutable_shape_do_not_use() { return shape_.get(); }
|
Shape* mutable_shape_do_not_use() { return shape_.get(); }
|
||||||
|
|
||||||
// Set the dynamic size on dim_index in the literal at the given shape_index.
|
|
||||||
void SetDynamicSize(int64 dim_index, const ShapeIndex& shape_index,
|
|
||||||
int32 size);
|
|
||||||
void SetDynamicSize(int64 dim_index, int32 size);
|
|
||||||
|
|
||||||
// Returns a pointer to the underlying buffer holding the array at the given
|
// Returns a pointer to the underlying buffer holding the array at the given
|
||||||
// shape index. CHECKs if the subshape of the literal at the given ShapeIndex
|
// shape index. CHECKs if the subshape of the literal at the given ShapeIndex
|
||||||
// is not array.
|
// is not array.
|
||||||
@ -605,12 +560,10 @@ class MutableLiteralBase : public LiteralBase {
|
|||||||
// Copy values from 'src_literal' rooted at 'src_shape_index' into this
|
// Copy values from 'src_literal' rooted at 'src_shape_index' into this
|
||||||
// literal rooted at 'dest_shape_index'. The subshape of this literal rooted
|
// literal rooted at 'dest_shape_index'. The subshape of this literal rooted
|
||||||
// at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
|
// at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
|
||||||
// rooted at 'src_shape_index', but need not be arrays. If only_dynamic_bound
|
// rooted at 'src_shape_index', but need not be arrays.
|
||||||
// is true, only elements within dynamic bounds will be copied.
|
|
||||||
Status CopyFrom(const LiteralSlice& src_literal,
|
Status CopyFrom(const LiteralSlice& src_literal,
|
||||||
const ShapeIndex& dest_shape_index = {},
|
const ShapeIndex& dest_shape_index = {},
|
||||||
const ShapeIndex& src_shape_index = {},
|
const ShapeIndex& src_shape_index = {});
|
||||||
bool only_dynamic_bound = false);
|
|
||||||
|
|
||||||
// Copies the values from src_literal, starting at src_base shape indexes,
|
// Copies the values from src_literal, starting at src_base shape indexes,
|
||||||
// to this literal, starting at dest_base, where the copy size in each
|
// to this literal, starting at dest_base, where the copy size in each
|
||||||
@ -971,14 +924,9 @@ void LiteralBase::EachCell(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
std::vector<int64> indices(shape().rank(), 0);
|
std::vector<int64> indices(shape().rank(), 0);
|
||||||
|
|
||||||
Shape shape_dynamic = shape();
|
|
||||||
for (int64 i = 0; i < shape_dynamic.rank(); ++i) {
|
|
||||||
shape_dynamic.set_dimensions(i, GetDynamicSize(i));
|
|
||||||
}
|
|
||||||
do {
|
do {
|
||||||
per_cell(indices, Get<NativeT>(indices));
|
per_cell(indices, Get<NativeT>(indices));
|
||||||
} while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices)));
|
} while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
|
@ -149,16 +149,6 @@ TEST_F(LiteralUtilTest, R2ToString) {
|
|||||||
EXPECT_EQ(expected, literal.ToString());
|
EXPECT_EQ(expected, literal.ToString());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(LiteralUtilTest, R2DynamicToString) {
|
|
||||||
auto literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}});
|
|
||||||
literal.SetDynamicSize(0, {}, 2);
|
|
||||||
const string expected = R"(s32[<=3,2](2,2) {
|
|
||||||
{ 1, 2 },
|
|
||||||
{ 3, 4 }
|
|
||||||
})";
|
|
||||||
EXPECT_EQ(expected, literal.ToString());
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(LiteralUtilTest, R3ToString) {
|
TEST_F(LiteralUtilTest, R3ToString) {
|
||||||
const auto literal =
|
const auto literal =
|
||||||
LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}});
|
LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}});
|
||||||
@ -431,28 +421,6 @@ TEST_F(LiteralUtilTest, TupleEquality) {
|
|||||||
EXPECT_NE(tuple1, different_tuple);
|
EXPECT_NE(tuple1, different_tuple);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(LiteralUtilTest, DynamicShapeEquality) {
|
|
||||||
// Test equality with tuples.
|
|
||||||
auto r1 = LiteralUtil::CreateR1<float>({1.0, 2.0});
|
|
||||||
r1.SetDynamicSize(0, {}, 1);
|
|
||||||
auto r2 = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
|
||||||
r2.SetDynamicSize(0, {}, 1);
|
|
||||||
auto tuple1 = LiteralUtil::MakeTuple({&r1, &r2});
|
|
||||||
|
|
||||||
// Tuple with the same elements. One element is shared with the original
|
|
||||||
// tuple, the other is a clone of the element in the original tuple.
|
|
||||||
auto r1_clone = LiteralUtil::CreateR1<float>({1.0, 3.0});
|
|
||||||
r1_clone.SetDynamicSize(0, {}, 1);
|
|
||||||
auto tuple2 = LiteralUtil::MakeTuple({&r1_clone, &r2});
|
|
||||||
EXPECT_EQ(tuple1, tuple2);
|
|
||||||
|
|
||||||
// Tuple with different dynamic sizes.
|
|
||||||
auto r2_clone = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
|
||||||
r2_clone.SetDynamicSize(0, {}, 2);
|
|
||||||
auto tuple_3 = LiteralUtil::MakeTuple({&r1_clone, &r2_clone});
|
|
||||||
EXPECT_NE(tuple1, tuple_3);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(LiteralUtilTest, C64Equality) {
|
TEST_F(LiteralUtilTest, C64Equality) {
|
||||||
// Test equality with tuples.
|
// Test equality with tuples.
|
||||||
auto vector = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
|
auto vector = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
|
||||||
@ -724,47 +692,6 @@ TEST_F(LiteralUtilTest, TransposeR4) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(LiteralUtilTest, TransposeDynamicR2) {
|
|
||||||
// F32[2, <=3] (2, 1)
|
|
||||||
auto original = LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}});
|
|
||||||
original.SetDynamicSize(1, 1);
|
|
||||||
// F32[<=3, 2] (1, 2)
|
|
||||||
auto reshape = original.Transpose(/*permutation=*/{1, 0});
|
|
||||||
|
|
||||||
reshape.EachCell<float>([&](absl::Span<const int64> indices, float value) {
|
|
||||||
EXPECT_EQ(value, original.Get<float>({indices[1], indices[0]}));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(LiteralUtilTest, ToStaticR2) {
|
|
||||||
// F32[2, <=3] (2, 1)
|
|
||||||
auto original = LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}});
|
|
||||||
original.SetDynamicSize(1, 1);
|
|
||||||
// F32[2, 1]
|
|
||||||
auto static_literal = original.ToStatic();
|
|
||||||
EXPECT_EQ(static_literal.shape(), ShapeUtil::MakeShape(F32, {2, 1}));
|
|
||||||
EXPECT_TRUE(static_literal.shape().is_static());
|
|
||||||
|
|
||||||
static_literal.EachCell<float>(
|
|
||||||
[&](absl::Span<const int64> indices, float value) {
|
|
||||||
EXPECT_EQ(value, original.Get<float>({indices[0], indices[1]}));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(LiteralUtilTest, ToBoundedDynamicR2) {
|
|
||||||
// F32[2, 1]
|
|
||||||
auto original = LiteralUtil::CreateR2<float>({{1}, {4}});
|
|
||||||
// F32[2, <=3] (2, 1)
|
|
||||||
auto dynamic_shape = ShapeUtil::MakeShape(F32, {2, 3}, {false, true});
|
|
||||||
auto dynamic_literal = original.ToBoundedDynamic(dynamic_shape);
|
|
||||||
EXPECT_EQ(dynamic_literal.shape(), dynamic_shape);
|
|
||||||
|
|
||||||
dynamic_literal.EachCell<float>(
|
|
||||||
[&](absl::Span<const int64> indices, float value) {
|
|
||||||
EXPECT_EQ(value, original.Get<float>({indices[0], indices[1]}));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
|
TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
|
||||||
// Tests that using Relayout on an array is equivalent to creating it in the
|
// Tests that using Relayout on an array is equivalent to creating it in the
|
||||||
// target layout in the first place.
|
// target layout in the first place.
|
||||||
@ -870,38 +797,6 @@ TEST_F(LiteralUtilTest, SliceR3U32Full) {
|
|||||||
EXPECT_EQ(input_2x3x2, result);
|
EXPECT_EQ(input_2x3x2, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(LiteralUtilTest, SliceR2Dynamic) {
|
|
||||||
auto input_3x4 = LiteralUtil::CreateR2<uint32>(
|
|
||||||
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
|
|
||||||
input_3x4.SetDynamicSize(1, 3);
|
|
||||||
// slice second dim from dynamic size 3 to dynamic size 1.
|
|
||||||
auto result = input_3x4.Slice({0, 1}, {2, 2});
|
|
||||||
auto expected = LiteralUtil::CreateR2<uint32>({{2}, {6}});
|
|
||||||
EXPECT_EQ(expected, result);
|
|
||||||
EXPECT_EQ(result.GetDynamicSize(1), 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(LiteralUtilTest, SliceR2DynamicInBound) {
|
|
||||||
auto input_3x4 = LiteralUtil::CreateR2<uint32>(
|
|
||||||
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
|
|
||||||
input_3x4.SetDynamicSize(1, 1);
|
|
||||||
auto result = input_3x4.Slice({0, 0}, {2, 2});
|
|
||||||
auto expected = LiteralUtil::CreateR2<uint32>({{1}, {5}});
|
|
||||||
EXPECT_EQ(expected, result);
|
|
||||||
EXPECT_EQ(result.GetDynamicSize(1), 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(LiteralUtilTest, SliceR2DynamicOutOfBound) {
|
|
||||||
auto input_3x4 = LiteralUtil::CreateR2<uint32>(
|
|
||||||
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
|
|
||||||
input_3x4.SetDynamicSize(1, 1);
|
|
||||||
auto result = input_3x4.Slice({0, 1}, {2, 3});
|
|
||||||
auto expected = LiteralUtil::CreateR2<uint32>({{}, {}});
|
|
||||||
EXPECT_EQ(expected, result);
|
|
||||||
// Out of bound access clamps into 0 sized dimension.
|
|
||||||
EXPECT_EQ(result.GetDynamicSize(1), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(LiteralUtilTest, PopulateR1S64) {
|
TEST_F(LiteralUtilTest, PopulateR1S64) {
|
||||||
Literal output(ShapeUtil::MakeShape(S64, {1}));
|
Literal output(ShapeUtil::MakeShape(S64, {1}));
|
||||||
output.PopulateR1<int64>({77});
|
output.PopulateR1<int64>({77});
|
||||||
@ -1615,7 +1510,7 @@ TEST_F(LiteralUtilTest, CopyFromProto_u16) {
|
|||||||
EXPECT_EQ(u1, r[3]);
|
EXPECT_EQ(u1, r[3]);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(LiteralUtilTest, LiteralDynamicSliceTest) {
|
TEST_F(LiteralUtilTest, LiteralSliceTest) {
|
||||||
auto scalar = LiteralUtil::CreateR0<float>(1.0);
|
auto scalar = LiteralUtil::CreateR0<float>(1.0);
|
||||||
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
||||||
auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
|
auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
|
||||||
@ -2078,17 +1973,6 @@ TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) {
|
|||||||
LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}}));
|
LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(LiteralUtilTest, DynamicBroadcast) {
|
|
||||||
Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
|
|
||||||
literal.SetDynamicSize(0, 1);
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(
|
|
||||||
Literal broadcasted_literal,
|
|
||||||
literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
|
|
||||||
/*dimensions=*/{1}));
|
|
||||||
EXPECT_EQ(broadcasted_literal, LiteralUtil::CreateR2<int64>({{1}, {1}}));
|
|
||||||
EXPECT_EQ(broadcasted_literal.GetDynamicSize(1), 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(LiteralUtilTest, GetAsComplex128) {
|
TEST_F(LiteralUtilTest, GetAsComplex128) {
|
||||||
complex128 value = {1, 0};
|
complex128 value = {1, 0};
|
||||||
Literal c1 = LiteralUtil::CreateR0<complex128>(value);
|
Literal c1 = LiteralUtil::CreateR0<complex128>(value);
|
||||||
|
@ -440,10 +440,6 @@ Status HloEvaluator::HandleSetDimensionSize(
|
|||||||
Literal result(set_dimension_size->shape());
|
Literal result(set_dimension_size->shape());
|
||||||
memcpy(result.untyped_data(), operand_literal.untyped_data(),
|
memcpy(result.untyped_data(), operand_literal.untyped_data(),
|
||||||
operand_literal.size_bytes());
|
operand_literal.size_bytes());
|
||||||
const Literal& size_literal =
|
|
||||||
GetEvaluatedLiteralFor(set_dimension_size->operand(1));
|
|
||||||
result.SetDynamicSize(set_dimension_size->dimension(),
|
|
||||||
size_literal.Get<int32>({}));
|
|
||||||
evaluated_[set_dimension_size] = std::move(result);
|
evaluated_[set_dimension_size] = std::move(result);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -81,17 +81,8 @@ StatusOr<ExecutionOutput> InterpreterExecutableBase::ExecuteAsyncOnStream(
|
|||||||
for (int64 i = 0; i < computation->num_parameters(); ++i) {
|
for (int64 i = 0; i < computation->num_parameters(); ++i) {
|
||||||
const auto& expected_shape = computation->parameter_instruction(i)->shape();
|
const auto& expected_shape = computation->parameter_instruction(i)->shape();
|
||||||
const auto& actual_shape = argument_buffers[i].on_device_shape();
|
const auto& actual_shape = argument_buffers[i].on_device_shape();
|
||||||
bool shape_match = true;
|
if (!Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape,
|
||||||
if (expected_shape.is_dynamic()) {
|
actual_shape)) {
|
||||||
if (!ShapeUtil::DynamicArrayShapeIsCompatible(actual_shape,
|
|
||||||
expected_shape)) {
|
|
||||||
shape_match = false;
|
|
||||||
}
|
|
||||||
} else if (!Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape,
|
|
||||||
actual_shape)) {
|
|
||||||
shape_match = false;
|
|
||||||
}
|
|
||||||
if (!shape_match) {
|
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"Shape mismatch on parameter %d. Expected %s, but was %s.", i,
|
"Shape mismatch on parameter %d. Expected %s, but was %s.", i,
|
||||||
ShapeUtil::HumanStringWithLayout(expected_shape),
|
ShapeUtil::HumanStringWithLayout(expected_shape),
|
||||||
@ -109,18 +100,11 @@ StatusOr<ExecutionOutput> InterpreterExecutableBase::ExecuteAsyncOnStream(
|
|||||||
TF_ASSIGN_OR_RETURN(Literal arg_literal,
|
TF_ASSIGN_OR_RETURN(Literal arg_literal,
|
||||||
transfer_manager->TransferLiteralFromDevice(
|
transfer_manager->TransferLiteralFromDevice(
|
||||||
run_options->stream(), argument_buffers[p]));
|
run_options->stream(), argument_buffers[p]));
|
||||||
const auto& expected_shape = computation->parameter_instruction(p)->shape();
|
|
||||||
if (expected_shape.is_dynamic()) {
|
|
||||||
// Expand the input literal to expected shape.
|
|
||||||
arg_literal = arg_literal.ToBoundedDynamic(expected_shape);
|
|
||||||
}
|
|
||||||
arg_literals.push_back(std::move(arg_literal));
|
arg_literals.push_back(std::move(arg_literal));
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(Literal result_literal,
|
TF_ASSIGN_OR_RETURN(Literal result_literal,
|
||||||
Evaluate(*computation, arg_literals));
|
Evaluate(*computation, arg_literals));
|
||||||
// Shrink the generated dynamic shape into static shape.
|
|
||||||
result_literal = result_literal.ToStatic();
|
|
||||||
|
|
||||||
// Transform the result literal back into a ShapedBuffer.
|
// Transform the result literal back into a ShapedBuffer.
|
||||||
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result_buffers,
|
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result_buffers,
|
||||||
|
@ -339,15 +339,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
|||||||
TF_DCHECK_OK(ValidateShape(*shape));
|
TF_DCHECK_OK(ValidateShape(*shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ void ShapeUtil::CopyDynamicDimensions(Shape* to,
|
|
||||||
const Shape& from) {
|
|
||||||
CHECK_EQ(to->rank(), from.rank());
|
|
||||||
for (int64 i = 0; i < from.rank(); ++i) {
|
|
||||||
to->set_dynamic_dimension(i, from.is_dynamic_dimension(i));
|
|
||||||
}
|
|
||||||
TF_DCHECK_OK(ValidateShape(*to));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* static */ bool ShapeUtil::ElementIsIntegral(const Shape& shape) {
|
/* static */ bool ShapeUtil::ElementIsIntegral(const Shape& shape) {
|
||||||
return primitive_util::IsIntegralType(shape.element_type());
|
return primitive_util::IsIntegralType(shape.element_type());
|
||||||
}
|
}
|
||||||
|
@ -377,9 +377,6 @@ class ShapeUtil {
|
|||||||
// Appends a major dimension to the shape with the given bound.
|
// Appends a major dimension to the shape with the given bound.
|
||||||
static void AppendMajorDimension(int bound, Shape* shape);
|
static void AppendMajorDimension(int bound, Shape* shape);
|
||||||
|
|
||||||
// Copy the dynamic dimensions property from one shape to another.
|
|
||||||
static void CopyDynamicDimensions(Shape* to, const Shape& from);
|
|
||||||
|
|
||||||
// Returns an empty tuple shape. Can be used as a sentinel Shape value.
|
// Returns an empty tuple shape. Can be used as a sentinel Shape value.
|
||||||
static Shape MakeNil() { return MakeTupleShape({}); }
|
static Shape MakeNil() { return MakeTupleShape({}); }
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user