Dynamic literal support

PiperOrigin-RevId: 321836977
Change-Id: Ib5524846de424da20643c6982f6454614d9ffa07
This commit is contained in:
A. Unique TensorFlower 2020-07-17 12:55:18 -07:00 committed by TensorFlower Gardener
parent c49b3f570d
commit ce190ec244
7 changed files with 34 additions and 451 deletions

View File

@ -48,10 +48,6 @@ namespace {
using absl::StrCat;
constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
// Literals can be used as DMA targets, which can require alignment. We
// force a tensorflow::Allocator::kAllocatorAlignment-byte minimum
// alignment.
constexpr int kMinimumAlignment = 64;
// Converts between little and big endian.
//
@ -137,14 +133,12 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
}
} else if (shape.IsArray()) {
if (allocate_arrays) {
// Literals can be used as DMA targets, which can require alignment. We
// force a tensorflow::Allocator::kAllocatorAlignment-byte minimum
// alignment.
constexpr int kMinimumAlignment = 64;
piece->set_buffer(static_cast<char*>(tensorflow::port::AlignedMalloc(
piece->size_bytes(), kMinimumAlignment)));
if (shape.is_dynamic()) {
CHECK_EQ(piece->dynamic_size_buffer(), nullptr);
piece->set_dynamic_size_buffer(
static_cast<int32*>(tensorflow::port::AlignedMalloc(
piece->dynamic_size_buffer_bytes(), kMinimumAlignment)));
}
}
} else {
// If the shape is neither an array nor tuple, then it must be
@ -177,9 +171,6 @@ void Literal::DeallocateBuffers() {
if (piece->buffer() != nullptr) {
tensorflow::port::AlignedFree(piece->buffer());
}
if (piece->dynamic_size_buffer() != nullptr) {
tensorflow::port::AlignedFree(piece->dynamic_size_buffer());
}
});
}
@ -208,15 +199,6 @@ Literal LiteralBase::CreateFromShape(const Shape& shape) {
return literal;
}
int32 LiteralBase::GetDynamicSize(int64 dim_index) const {
return GetDynamicSize(dim_index, {});
}
int32 LiteralBase::GetDynamicSize(int64 dim_index,
const ShapeIndex& shape_index) const {
return piece(shape_index).GetDynamicSize(dim_index);
}
absl::optional<int64> LiteralBase::GetFirstInteger() const {
switch (shape().element_type()) {
case U8:
@ -399,9 +381,7 @@ std::vector<Literal> Literal::DecomposeTuple() {
// Move the respective buffer over to the element Literal.
dest_piece->set_buffer(src_piece.buffer());
dest_piece->set_dynamic_size_buffer(src_piece.dynamic_size_buffer());
src_piece.set_buffer(nullptr);
src_piece.set_dynamic_size_buffer(nullptr);
});
}
// Set this literal to be nil-shaped.
@ -427,51 +407,23 @@ void CopyElementsBetween(absl::Span<NativeT> dest,
src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
} while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index)));
}
} // namespace
int32 LiteralBase::Piece::GetDynamicSize(int64 dim_index) const {
CHECK(LayoutUtil::IsDenseArray(subshape()));
if (!subshape_->is_dynamic_dimension(dim_index)) {
// This is a static dimension, return size.
return subshape_->dimensions(dim_index);
}
CHECK_NE(dynamic_size_buffer(), nullptr);
return dynamic_size_buffer_[dim_index];
}
void LiteralBase::Piece::SetDynamicSize(int64 dim_index, int32 size) {
CHECK(LayoutUtil::IsDenseArray(subshape()));
CHECK(subshape_->is_dynamic_dimension(dim_index));
if (dynamic_size_buffer() == nullptr) {
// Lazily initialize the dynamic size buffer.
set_dynamic_size_buffer(static_cast<int32*>(tensorflow::port::AlignedMalloc(
dynamic_size_buffer_bytes(), kMinimumAlignment)));
/*for (int64 i = 0; i < subshape().rank(); ++i) {
// Initialized to -1 to help debug.
dynamic_size_buffer_[i] = -1;
}*/
}
dynamic_size_buffer_[dim_index] = size;
}
Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src,
bool only_dynamic_bound) {
Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
CHECK(subshape_ != nullptr);
CHECK(src.subshape_ != nullptr);
if (ShapeUtil::Equal(subshape(), src.subshape())) {
// If the layouts are equal it's faster just to memcpy.
memcpy(buffer(), src.buffer(), src.size_bytes());
} else {
TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape()));
std::vector<int64> origin(subshape().rank(), 0);
switch (subshape().element_type()) {
#define COPY_ELEMENTS(XLA_T, NATIVE_T) \
case (XLA_T): \
if (only_dynamic_bound) { \
CopyElementsWithDynamicBound<NATIVE_T>(src); \
} else { \
CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
subshape(), src.subshape()); \
} \
#define COPY_ELEMENTS(XLA_T, NATIVE_T) \
case (XLA_T): \
CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
subshape(), src.subshape()); \
break;
COPY_ELEMENTS(U8, uint8);
COPY_ELEMENTS(U16, uint16);
@ -495,54 +447,21 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src,
PrimitiveType_Name(subshape().element_type()));
}
}
DCHECK_EQ(dynamic_size_buffer_bytes(), src.dynamic_size_buffer_bytes());
if (subshape().is_dynamic() && src.subshape().is_dynamic()) {
CHECK_NE(dynamic_size_buffer_, nullptr);
CHECK_NE(src.dynamic_size_buffer_, nullptr);
memcpy(dynamic_size_buffer(), src.dynamic_size_buffer(),
src.dynamic_size_buffer_bytes());
}
return Status::OK();
}
void MutableLiteralBase::SetDynamicSize(int64 dim_index, int32 size) {
return SetDynamicSize(dim_index, {}, size);
}
void MutableLiteralBase::SetDynamicSize(int64 dim_index,
const ShapeIndex& shape_index,
int32 size) {
Shape* subshape_ = ShapeUtil::GetMutableSubshape(shape_.get(), shape_index);
CHECK_GE(subshape_->dimensions(dim_index), size);
if (subshape_->dimensions(dim_index) == size) {
subshape_->set_dynamic_dimension(dim_index, false);
return;
}
subshape_->set_dynamic_dimension(dim_index, true);
piece(shape_index).SetDynamicSize(dim_index, size);
}
Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
const ShapeIndex& dest_shape_index,
const ShapeIndex& src_shape_index,
bool only_dynamic_bound) {
const ShapeIndex& src_shape_index) {
const Shape& dest_subshape =
ShapeUtil::GetSubshape(shape(), dest_shape_index);
const Shape& src_subshape =
ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
if (only_dynamic_bound) {
auto bound_shape = dest_subshape.is_static() ? src_subshape : dest_subshape;
auto compact_shape =
dest_subshape.is_static() ? dest_subshape : src_subshape;
CHECK(ShapeUtil::DynamicShapeIsCompatible(compact_shape, bound_shape))
<< compact_shape.ToString() << " vs " << bound_shape.ToString();
} else {
if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
return InvalidArgument(
"Destination subshape incompatible with source subshape: %s vs %s",
ShapeUtil::HumanString(dest_subshape),
ShapeUtil::HumanString(src_subshape));
}
if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
return InvalidArgument(
"Destination subshape incompatible with source subshape: %s vs %s",
ShapeUtil::HumanString(dest_subshape),
ShapeUtil::HumanString(src_subshape));
}
return root_piece_->ForEachMutableSubpieceWithStatus(
[&](const ShapeIndex& index, Piece* piece) {
@ -567,9 +486,7 @@ Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
src_piece_index.push_back(index[i]);
}
TF_RETURN_IF_ERROR(
piece->CopyFrom(src_literal.piece(src_piece_index),
/*only_dynamic_bound=*/only_dynamic_bound));
TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index)));
return Status::OK();
});
}
@ -597,9 +514,7 @@ Status Literal::MoveFrom(Literal&& src_literal,
}
Piece& dest_piece = piece(dest_index);
tensorflow::port::AlignedFree(dest_piece.buffer());
tensorflow::port::AlignedFree(dest_piece.dynamic_size_buffer());
dest_piece.set_buffer(src_piece.buffer());
dest_piece.set_dynamic_size_buffer(src_piece.dynamic_size_buffer());
});
src_literal.shape_ = absl::make_unique<Shape>(ShapeUtil::MakeNil());
@ -714,41 +629,6 @@ Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
return result;
}
Literal LiteralBase::ToBoundedDynamic(const Shape& bounded_shape) const {
CHECK(bounded_shape.is_dynamic());
Literal result(bounded_shape);
ShapeUtil::ForEachSubshape(
shape(), [&](const Shape& subshape, const ShapeIndex& index) {
if (!subshape.IsArray()) {
return;
}
for (int64 i = 0; i < subshape.rank(); ++i) {
result.SetDynamicSize(i, subshape.dimensions(i));
}
});
TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true));
return result;
}
Literal LiteralBase::ToStatic() const {
// Create new shape with 'new_layout' set at the given shape index.
Shape new_shape = shape();
ShapeUtil::ForEachMutableSubshape(
&new_shape, [this](Shape* subshape, const ShapeIndex& index) {
if (!subshape->IsArray()) {
return;
}
for (int64 i = 0; i < subshape->rank(); ++i) {
subshape->set_dynamic_dimension(i, false);
subshape->set_dimensions(i, GetDynamicSize(i, index));
}
});
Literal result(new_shape);
TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true));
return result;
}
StatusOr<Literal> LiteralBase::Broadcast(
const Shape& result_shape, absl::Span<const int64> dimensions) const {
if (!shape().IsArray()) {
@ -772,11 +652,6 @@ StatusOr<Literal> LiteralBase::Broadcast(
const int64 primitive_size =
ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
for (int64 i = 0; i < dimensions.size(); ++i) {
int64 dynamic_size = GetDynamicSize(i);
result.SetDynamicSize(dimensions[i], dynamic_size);
}
ShapeUtil::ForEachIndex(
result_shape, [&](absl::Span<const int64> output_index) {
for (int64 i = 0; i < dimensions.size(); ++i) {
@ -799,9 +674,6 @@ StatusOr<Literal> LiteralBase::Reshape(
if (!shape().IsArray()) {
return InvalidArgument("Reshape does not support tuples.");
}
if (shape().is_dynamic()) {
return Unimplemented("Dynamic reshape is not implemented.");
}
Literal output;
if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank()));
@ -856,9 +728,6 @@ Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
layout->add_minor_to_major(inverse_permutation[index]);
}
Literal new_literal(permuted_shape);
for (int64 i = 0; i < shape().rank(); i++) {
new_literal.SetDynamicSize(inverse_permutation[i], GetDynamicSize(i));
}
DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
ShapeUtil::ByteSizeOf(shape()));
std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
@ -878,14 +747,6 @@ Literal LiteralBase::SliceInternal(
return Get<NativeT>(new_indices);
})
.ok());
for (int64 dnum = 0; dnum < shape().rank(); ++dnum) {
if (shape().is_dynamic_dimension(dnum)) {
int64 dynamic_size = GetDynamicSize(dnum) - start_indices[dnum];
CHECK_GE(dynamic_size, 0) << GetDynamicSize(dnum);
dynamic_size = std::min(dynamic_size, result_shape.dimensions(dnum));
result_literal.SetDynamicSize(dnum, dynamic_size);
}
}
return result_literal;
}
@ -902,10 +763,9 @@ Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
CHECK_GE(dimension, 0) << "dnum = " << dnum;
result_dimensions.push_back(dimension);
}
auto result_shape =
const auto result_shape =
ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
LayoutUtil::MinorToMajor(shape()));
ShapeUtil::CopyDynamicDimensions(&result_shape, shape());
switch (result_shape.element_type()) {
case PRED:
return SliceInternal<bool>(result_shape, start_indices);
@ -1222,24 +1082,11 @@ void DenseArrayToStringHelper(const LiteralBase& literal,
if (print_shape) {
pieces->push_back(ShapeToString(print_layout, subshape));
if (subshape.is_dynamic()) {
pieces->push_back("(");
for (int64 i = 0; i < subshape.dimensions_size(); ++i) {
pieces->push_back(StrCat(literal.GetDynamicSize(i, shape_index)));
if (i < subshape.dimensions_size() - 1) {
pieces->push_back(",");
}
}
pieces->push_back(")");
}
pieces->push_back(" ");
}
std::vector<int64> indices = {};
std::vector<int64> dimensions;
dimensions.reserve(subshape.rank());
for (int64 i = 0; i < subshape.rank(); ++i) {
dimensions.push_back(literal.GetDynamicSize(i, shape_index));
}
std::vector<int64> dimensions(subshape.dimensions().begin(),
subshape.dimensions().end());
to_string_recursive(dimensions, &indices);
}
@ -1527,44 +1374,13 @@ StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const {
return literal;
}
template <typename NativeT>
void LiteralBase::Piece::CopyElementsWithDynamicBound(
const LiteralBase::Piece& src) {
auto dest_shape = subshape();
auto src_shape = src.subshape();
// At least one shape has to be static as bound.
CHECK(dest_shape.is_static() || src_shape.is_static());
auto bound_shape = dest_shape.is_static() ? src_shape : dest_shape;
if (ShapeUtil::IsZeroElementArray(dest_shape)) {
return;
}
std::vector<int64> index(dest_shape.rank());
do {
bool out_of_bound = false;
for (int64 i = 0; i < index.size(); ++i) {
// Do not copy elements beyond dynamic bound.
if (index[i] >= GetDynamicSize(i) || index[i] >= src.GetDynamicSize(i)) {
out_of_bound = true;
}
}
if (out_of_bound) {
continue;
}
data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape,
index)] =
src.data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
src_shape, index)];
} while (IndexUtil::BumpIndices(bound_shape, absl::MakeSpan(index)));
}
template <typename NativeT>
bool LiteralBase::Piece::EqualElementsInternal(
const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
if (multi_index->size() == subshape().rank()) {
return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
}
for (int64 i = 0; i < GetDynamicSize(multi_index->size()); ++i) {
for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) {
multi_index->push_back(i);
if (!EqualElementsInternal<NativeT>(other, multi_index)) {
return false;
@ -1574,26 +1390,10 @@ bool LiteralBase::Piece::EqualElementsInternal(
return true;
}
bool LiteralBase::Piece::EqualDynamicSize(
const LiteralBase::Piece& other) const {
DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
if (subshape().is_static()) {
return true;
}
for (int64 i = 0; i < subshape().rank(); ++i) {
if (GetDynamicSize(i) != other.GetDynamicSize(i)) {
return false;
}
}
return true;
}
bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
if (subshape().is_static() &&
ShapeUtil::Equal(subshape(), other.subshape()) &&
if (ShapeUtil::Equal(subshape(), other.subshape()) &&
LayoutUtil::IsDenseArray(subshape())) {
CHECK_EQ(size_bytes(), other.size_bytes());
return memcmp(buffer(), other.buffer(), size_bytes()) == 0;
@ -1636,33 +1436,17 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
}
bool LiteralBase::operator==(const LiteralBase& other) const {
// Checking the structure of tuple literals. Checks for dense arrays are
// performed below.
if (!ShapeUtil::EqualStructure(shape(), other.shape())) {
if (!ShapeUtil::Compatible(shape(), other.shape())) {
return false;
}
return root_piece().ForEachSubpieceWithBool(
[&](const ShapeIndex& index, const Piece& piece) {
const Piece& other_piece = other.piece(index);
const Shape& subshape = piece.subshape();
const Shape& other_subshape = other_piece.subshape();
if (subshape.element_type() != other_subshape.element_type()) {
return false;
}
if (!piece.subshape().IsArray()) {
return true;
}
if (subshape.rank() != other_subshape.rank()) {
return false;
}
for (int64 i = 0; i < subshape.rank(); ++i) {
if (piece.GetDynamicSize(i) != other_piece.GetDynamicSize(i)) {
return false;
}
}
const Piece& other_piece = other.piece(index);
if (!piece.EqualElements(other_piece)) {
return false;
}
@ -2251,7 +2035,6 @@ void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape,
}
} else if (shape.IsArray()) {
dest_piece->set_buffer(src_piece->buffer());
dest_piece->set_dynamic_size_buffer(src_piece->dynamic_size_buffer());
} else {
// If the shape is neither an array nor tuple, then it must be
// zero-sized. Otherwise, some memory needs to be allocated for it.

View File

@ -112,10 +112,6 @@ class LiteralBase {
template <typename NativeT>
NativeT Get(absl::Span<const int64> multi_index) const;
// Get the dynamic size on dim_index in the literal at the given shape_index.
int32 GetDynamicSize(int64 dim_index, const ShapeIndex& shape_index) const;
int32 GetDynamicSize(int64 dim_index) const;
// Returns the element value at index (0, ..., 0), however many zeroes are
// required for that index.
template <typename NativeT>
@ -285,18 +281,6 @@ class LiteralBase {
// than being limited to a single array within the shape.
Literal Relayout(const Shape& shape_with_layout) const;
// Generate a new literal whose static sizes are equal to the previous
// literal's dynamic sizes.
Literal ToStatic() const;
// Expand a static literal into a new one with a bounded dyanmic literal. The
// static dimensions of the original literal becomes dynamic dimensions of the
// new literal, where the argument `bounded_shape` becomes the bounded shape
// of the new literal.
//
// Precondition: bounded_shape.is_dynamic()
Literal ToBoundedDynamic(const Shape& bounded_shape) const;
// Creates a new literal by reshaping this literal to have the given
// dimensions. The total number of elements must not change; The
// implementation currently only supports monotonic dim0-major layouts.
@ -370,22 +354,10 @@ class LiteralBase {
template <typename NativeT>
void Set(absl::Span<const int64> index, NativeT value);
int32 GetDynamicSize(int64 dim_index) const;
void SetDynamicSize(int64 dim_index, int32 size);
// Gets/sets the buffer holding the array data.
char* buffer() const { return buffer_; }
void set_buffer(char* buffer) { buffer_ = buffer; }
// Gets/sets the buffer holding dynamic sizes.
int32* dynamic_size_buffer() const { return dynamic_size_buffer_; }
void set_dynamic_size_buffer(int32* dynamic_size_buffer) {
dynamic_size_buffer_ = dynamic_size_buffer;
}
int64 dynamic_size_buffer_bytes() const {
return subshape().dimensions_size() * sizeof(int32);
}
// Gets or sets the subshape of this piece. This reference points to a
// subshape within the shape in the containing Literal (Literal::shape_).
const Shape& subshape() const { return *subshape_; }
@ -462,21 +434,15 @@ class LiteralBase {
}
// Returns true if this piece and 'other' contain the same data. This piece
// and 'other' must be array-shaped and compatible. If a literal has dynamic
// shape, comparison is done only for the valid elements.
// and 'other' must be array-shaped and compatible.
bool EqualElements(const Piece& other) const;
// Returns true if this piece and other pieces have the same dynamic
// dimension sizes.
bool EqualDynamicSize(const Piece& other) const;
// Writes the shape and data (if array-shaped) into the given proto.
void WriteToProto(LiteralProto* proto) const;
// Copy the data from 'src' into this piece's buffer. Shapes of this piece
// and src must be compatible. If only_dynamic_bound is true, only elements
// within dynamic bounds will be copied.
Status CopyFrom(const Piece& src, bool only_dynamic_bound);
// and src must be compatible.
Status CopyFrom(const Piece& src);
// Copies the data from the given proto into this piece. The shape of this
// piece must be equal (not just compatible) to the shape of the proto.
@ -531,15 +497,9 @@ class LiteralBase {
bool EqualElementsInternal(const Piece& other,
std::vector<int64>* multi_index) const;
// Internal helper to copy elements from another given piece
template <typename NativeT>
void CopyElementsWithDynamicBound(const LiteralBase::Piece& src);
// For array-shaped pieces, this is the buffer holding the literal data.
char* buffer_ = nullptr;
int32* dynamic_size_buffer_ = nullptr;
// The shape of piece. This points into the shape of the containing Literal
// (Literal::shape_).
const Shape* subshape_ = nullptr;
@ -590,11 +550,6 @@ class MutableLiteralBase : public LiteralBase {
// mutate the shape as this can produce malformed Literals.
Shape* mutable_shape_do_not_use() { return shape_.get(); }
// Set the dynamic size on dim_index in the literal at the given shape_index.
void SetDynamicSize(int64 dim_index, const ShapeIndex& shape_index,
int32 size);
void SetDynamicSize(int64 dim_index, int32 size);
// Returns a pointer to the underlying buffer holding the array at the given
// shape index. CHECKs if the subshape of the literal at the given ShapeIndex
// is not array.
@ -605,12 +560,10 @@ class MutableLiteralBase : public LiteralBase {
// Copy values from 'src_literal' rooted at 'src_shape_index' into this
// literal rooted at 'dest_shape_index'. The subshape of this literal rooted
// at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
// rooted at 'src_shape_index', but need not be arrays. If only_dynamic_bound
// is true, only elements within dynamic bounds will be copied.
// rooted at 'src_shape_index', but need not be arrays.
Status CopyFrom(const LiteralSlice& src_literal,
const ShapeIndex& dest_shape_index = {},
const ShapeIndex& src_shape_index = {},
bool only_dynamic_bound = false);
const ShapeIndex& src_shape_index = {});
// Copies the values from src_literal, starting at src_base shape indexes,
// to this literal, starting at dest_base, where the copy size in each
@ -971,14 +924,9 @@ void LiteralBase::EachCell(
return;
}
std::vector<int64> indices(shape().rank(), 0);
Shape shape_dynamic = shape();
for (int64 i = 0; i < shape_dynamic.rank(); ++i) {
shape_dynamic.set_dimensions(i, GetDynamicSize(i));
}
do {
per_cell(indices, Get<NativeT>(indices));
} while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices)));
} while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
}
template <typename NativeT>

View File

@ -149,16 +149,6 @@ TEST_F(LiteralUtilTest, R2ToString) {
EXPECT_EQ(expected, literal.ToString());
}
TEST_F(LiteralUtilTest, R2DynamicToString) {
auto literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}});
literal.SetDynamicSize(0, {}, 2);
const string expected = R"(s32[<=3,2](2,2) {
{ 1, 2 },
{ 3, 4 }
})";
EXPECT_EQ(expected, literal.ToString());
}
TEST_F(LiteralUtilTest, R3ToString) {
const auto literal =
LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}});
@ -431,28 +421,6 @@ TEST_F(LiteralUtilTest, TupleEquality) {
EXPECT_NE(tuple1, different_tuple);
}
TEST_F(LiteralUtilTest, DynamicShapeEquality) {
// Test equality with tuples.
auto r1 = LiteralUtil::CreateR1<float>({1.0, 2.0});
r1.SetDynamicSize(0, {}, 1);
auto r2 = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
r2.SetDynamicSize(0, {}, 1);
auto tuple1 = LiteralUtil::MakeTuple({&r1, &r2});
// Tuple with the same elements. One element is shared with the original
// tuple, the other is a clone of the element in the original tuple.
auto r1_clone = LiteralUtil::CreateR1<float>({1.0, 3.0});
r1_clone.SetDynamicSize(0, {}, 1);
auto tuple2 = LiteralUtil::MakeTuple({&r1_clone, &r2});
EXPECT_EQ(tuple1, tuple2);
// Tuple with different dynamic sizes.
auto r2_clone = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
r2_clone.SetDynamicSize(0, {}, 2);
auto tuple_3 = LiteralUtil::MakeTuple({&r1_clone, &r2_clone});
EXPECT_NE(tuple1, tuple_3);
}
TEST_F(LiteralUtilTest, C64Equality) {
// Test equality with tuples.
auto vector = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
@ -724,47 +692,6 @@ TEST_F(LiteralUtilTest, TransposeR4) {
});
}
TEST_F(LiteralUtilTest, TransposeDynamicR2) {
// F32[2, <=3] (2, 1)
auto original = LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}});
original.SetDynamicSize(1, 1);
// F32[<=3, 2] (1, 2)
auto reshape = original.Transpose(/*permutation=*/{1, 0});
reshape.EachCell<float>([&](absl::Span<const int64> indices, float value) {
EXPECT_EQ(value, original.Get<float>({indices[1], indices[0]}));
});
}
TEST_F(LiteralUtilTest, ToStaticR2) {
// F32[2, <=3] (2, 1)
auto original = LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}});
original.SetDynamicSize(1, 1);
// F32[2, 1]
auto static_literal = original.ToStatic();
EXPECT_EQ(static_literal.shape(), ShapeUtil::MakeShape(F32, {2, 1}));
EXPECT_TRUE(static_literal.shape().is_static());
static_literal.EachCell<float>(
[&](absl::Span<const int64> indices, float value) {
EXPECT_EQ(value, original.Get<float>({indices[0], indices[1]}));
});
}
TEST_F(LiteralUtilTest, ToBoundedDynamicR2) {
// F32[2, 1]
auto original = LiteralUtil::CreateR2<float>({{1}, {4}});
// F32[2, <=3] (2, 1)
auto dynamic_shape = ShapeUtil::MakeShape(F32, {2, 3}, {false, true});
auto dynamic_literal = original.ToBoundedDynamic(dynamic_shape);
EXPECT_EQ(dynamic_literal.shape(), dynamic_shape);
dynamic_literal.EachCell<float>(
[&](absl::Span<const int64> indices, float value) {
EXPECT_EQ(value, original.Get<float>({indices[0], indices[1]}));
});
}
TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
// Tests that using Relayout on an array is equivalent to creating it in the
// target layout in the first place.
@ -870,38 +797,6 @@ TEST_F(LiteralUtilTest, SliceR3U32Full) {
EXPECT_EQ(input_2x3x2, result);
}
TEST_F(LiteralUtilTest, SliceR2Dynamic) {
auto input_3x4 = LiteralUtil::CreateR2<uint32>(
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
input_3x4.SetDynamicSize(1, 3);
// slice second dim from dynamic size 3 to dynamic size 1.
auto result = input_3x4.Slice({0, 1}, {2, 2});
auto expected = LiteralUtil::CreateR2<uint32>({{2}, {6}});
EXPECT_EQ(expected, result);
EXPECT_EQ(result.GetDynamicSize(1), 1);
}
TEST_F(LiteralUtilTest, SliceR2DynamicInBound) {
auto input_3x4 = LiteralUtil::CreateR2<uint32>(
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
input_3x4.SetDynamicSize(1, 1);
auto result = input_3x4.Slice({0, 0}, {2, 2});
auto expected = LiteralUtil::CreateR2<uint32>({{1}, {5}});
EXPECT_EQ(expected, result);
EXPECT_EQ(result.GetDynamicSize(1), 1);
}
TEST_F(LiteralUtilTest, SliceR2DynamicOutOfBound) {
auto input_3x4 = LiteralUtil::CreateR2<uint32>(
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
input_3x4.SetDynamicSize(1, 1);
auto result = input_3x4.Slice({0, 1}, {2, 3});
auto expected = LiteralUtil::CreateR2<uint32>({{}, {}});
EXPECT_EQ(expected, result);
// Out of bound access clamps into 0 sized dimension.
EXPECT_EQ(result.GetDynamicSize(1), 0);
}
TEST_F(LiteralUtilTest, PopulateR1S64) {
Literal output(ShapeUtil::MakeShape(S64, {1}));
output.PopulateR1<int64>({77});
@ -1615,7 +1510,7 @@ TEST_F(LiteralUtilTest, CopyFromProto_u16) {
EXPECT_EQ(u1, r[3]);
}
TEST_F(LiteralUtilTest, LiteralDynamicSliceTest) {
TEST_F(LiteralUtilTest, LiteralSliceTest) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
@ -2078,17 +1973,6 @@ TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) {
LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}}));
}
TEST_F(LiteralUtilTest, DynamicBroadcast) {
Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
literal.SetDynamicSize(0, 1);
TF_ASSERT_OK_AND_ASSIGN(
Literal broadcasted_literal,
literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
/*dimensions=*/{1}));
EXPECT_EQ(broadcasted_literal, LiteralUtil::CreateR2<int64>({{1}, {1}}));
EXPECT_EQ(broadcasted_literal.GetDynamicSize(1), 1);
}
TEST_F(LiteralUtilTest, GetAsComplex128) {
complex128 value = {1, 0};
Literal c1 = LiteralUtil::CreateR0<complex128>(value);

View File

@ -440,10 +440,6 @@ Status HloEvaluator::HandleSetDimensionSize(
Literal result(set_dimension_size->shape());
memcpy(result.untyped_data(), operand_literal.untyped_data(),
operand_literal.size_bytes());
const Literal& size_literal =
GetEvaluatedLiteralFor(set_dimension_size->operand(1));
result.SetDynamicSize(set_dimension_size->dimension(),
size_literal.Get<int32>({}));
evaluated_[set_dimension_size] = std::move(result);
return Status::OK();
}

View File

@ -81,17 +81,8 @@ StatusOr<ExecutionOutput> InterpreterExecutableBase::ExecuteAsyncOnStream(
for (int64 i = 0; i < computation->num_parameters(); ++i) {
const auto& expected_shape = computation->parameter_instruction(i)->shape();
const auto& actual_shape = argument_buffers[i].on_device_shape();
bool shape_match = true;
if (expected_shape.is_dynamic()) {
if (!ShapeUtil::DynamicArrayShapeIsCompatible(actual_shape,
expected_shape)) {
shape_match = false;
}
} else if (!Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape,
actual_shape)) {
shape_match = false;
}
if (!shape_match) {
if (!Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape,
actual_shape)) {
return InvalidArgument(
"Shape mismatch on parameter %d. Expected %s, but was %s.", i,
ShapeUtil::HumanStringWithLayout(expected_shape),
@ -109,18 +100,11 @@ StatusOr<ExecutionOutput> InterpreterExecutableBase::ExecuteAsyncOnStream(
TF_ASSIGN_OR_RETURN(Literal arg_literal,
transfer_manager->TransferLiteralFromDevice(
run_options->stream(), argument_buffers[p]));
const auto& expected_shape = computation->parameter_instruction(p)->shape();
if (expected_shape.is_dynamic()) {
// Expand the input literal to expected shape.
arg_literal = arg_literal.ToBoundedDynamic(expected_shape);
}
arg_literals.push_back(std::move(arg_literal));
}
TF_ASSIGN_OR_RETURN(Literal result_literal,
Evaluate(*computation, arg_literals));
// Shrink the generated dynamic shape into static shape.
result_literal = result_literal.ToStatic();
// Transform the result literal back into a ShapedBuffer.
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result_buffers,

View File

@ -339,15 +339,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
TF_DCHECK_OK(ValidateShape(*shape));
}
/* static */ void ShapeUtil::CopyDynamicDimensions(Shape* to,
const Shape& from) {
CHECK_EQ(to->rank(), from.rank());
for (int64 i = 0; i < from.rank(); ++i) {
to->set_dynamic_dimension(i, from.is_dynamic_dimension(i));
}
TF_DCHECK_OK(ValidateShape(*to));
}
/* static */ bool ShapeUtil::ElementIsIntegral(const Shape& shape) {
return primitive_util::IsIntegralType(shape.element_type());
}

View File

@ -377,9 +377,6 @@ class ShapeUtil {
// Appends a major dimension to the shape with the given bound.
static void AppendMajorDimension(int bound, Shape* shape);
// Copy the dynamic dimensions property from one shape to another.
static void CopyDynamicDimensions(Shape* to, const Shape& from);
// Returns an empty tuple shape. Can be used as a sentinel Shape value.
static Shape MakeNil() { return MakeTupleShape({}); }