[XLA] Remove unsupported sparse layout
Sparse layouts are not supported on any of the backends. For backwards compatibility the fields stay in the protobuf, but parsing them is a no-op. PiperOrigin-RevId: 287924498 Change-Id: I8b1c1ec52e3a423015837bc10deee832921ba66c
This commit is contained in:
parent
1a416ed6a5
commit
2c431b6169
@ -417,7 +417,6 @@ cc_library(
|
||||
":array3d",
|
||||
":array4d",
|
||||
":shape_util",
|
||||
":sparse_index_array",
|
||||
":status_macros",
|
||||
":types",
|
||||
":util",
|
||||
@ -463,7 +462,6 @@ cc_library(
|
||||
":array4d",
|
||||
":literal",
|
||||
":shape_util",
|
||||
":sparse_index_array",
|
||||
":status_macros",
|
||||
":types",
|
||||
":util",
|
||||
@ -840,29 +838,6 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "sparse_index_array",
|
||||
srcs = ["sparse_index_array.cc"],
|
||||
hdrs = ["sparse_index_array.h"],
|
||||
deps = [
|
||||
":array2d",
|
||||
":shape_util",
|
||||
":xla_data_proto_cc",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "sparse_index_array_test",
|
||||
srcs = ["sparse_index_array_test.cc"],
|
||||
deps = [
|
||||
":sparse_index_array",
|
||||
":test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "parse_flags_from_env",
|
||||
srcs = ["parse_flags_from_env.cc"],
|
||||
|
@ -52,7 +52,6 @@ string Tile::ToString() const {
|
||||
for (const int64 dimension : proto.minor_to_major()) {
|
||||
layout.add_minor_to_major(dimension);
|
||||
}
|
||||
layout.set_max_sparse_elements(proto.max_sparse_elements());
|
||||
for (const TileProto& tile_proto : proto.tiles()) {
|
||||
*layout.add_tiles() = Tile::CreateFromProto(tile_proto);
|
||||
}
|
||||
@ -68,7 +67,6 @@ LayoutProto Layout::ToProto() const {
|
||||
for (const int64 dimension : minor_to_major()) {
|
||||
proto.add_minor_to_major(dimension);
|
||||
}
|
||||
proto.set_max_sparse_elements(max_sparse_elements_);
|
||||
for (const Tile& tile : tiles()) {
|
||||
*proto.add_tiles() = tile.ToProto();
|
||||
}
|
||||
@ -78,10 +76,7 @@ LayoutProto Layout::ToProto() const {
|
||||
}
|
||||
|
||||
string Layout::ToString() const {
|
||||
if (format() == SPARSE) {
|
||||
CHECK_EQ(tiles_size(), 0) << "Sparse layout should not be tiled.";
|
||||
return absl::StrCat("sparse{", max_sparse_elements(), "}");
|
||||
} else if (format() == DENSE) {
|
||||
if (format() == DENSE) {
|
||||
string colon_string = tiles().empty() ? "" : "T";
|
||||
for (Tile tile : tiles()) {
|
||||
absl::StrAppend(&colon_string, tile.ToString());
|
||||
@ -107,10 +102,6 @@ bool Layout::Equal::operator()(const Layout& lhs, const Layout& rhs) {
|
||||
if (lhs.format() == DENSE && lhs.minor_to_major() != rhs.minor_to_major()) {
|
||||
return false;
|
||||
}
|
||||
if (lhs.format() == SPARSE &&
|
||||
lhs.max_sparse_elements() != rhs.max_sparse_elements()) {
|
||||
return false;
|
||||
}
|
||||
if (!ignore_tiles_ && lhs.tiles() != rhs.tiles()) {
|
||||
return false;
|
||||
}
|
||||
|
@ -203,12 +203,6 @@ class Layout {
|
||||
absl::Span<const Tile> tiles() const { return tiles_; }
|
||||
absl::InlinedVector<Tile, 2>* mutable_tiles() { return &tiles_; }
|
||||
|
||||
// Methods for accessing the int64 fields.
|
||||
int64 max_sparse_elements() const { return max_sparse_elements_; }
|
||||
Layout& set_max_sparse_elements(int64 value) {
|
||||
max_sparse_elements_ = value;
|
||||
return *this;
|
||||
}
|
||||
int64 element_size_in_bits() const { return element_size_in_bits_; }
|
||||
Layout& set_element_size_in_bits(int64 value) {
|
||||
element_size_in_bits_ = value;
|
||||
@ -233,8 +227,7 @@ class Layout {
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const Layout& l) {
|
||||
return H::combine(std::move(h), l.format_, l.minor_to_major_,
|
||||
l.max_sparse_elements_, l.tiles_,
|
||||
return H::combine(std::move(h), l.format_, l.minor_to_major_, l.tiles_,
|
||||
l.element_size_in_bits_);
|
||||
}
|
||||
|
||||
@ -255,11 +248,6 @@ class Layout {
|
||||
// And the major dim is [8,100,100,3][1], which is size 100.
|
||||
absl::InlinedVector<int64, 6> minor_to_major_;
|
||||
|
||||
// The maximum number of elements that can be stored for SPARSE formats. This
|
||||
// can be used to determine the maximum size in bytes of arrays stored in
|
||||
// memory. This field must be zero unless the format is SPARSE.
|
||||
int64 max_sparse_elements_ = 0;
|
||||
|
||||
// The tiles used in tiling-based layout.
|
||||
absl::InlinedVector<Tile, 2> tiles_;
|
||||
|
||||
|
@ -34,8 +34,6 @@ class LayoutTest : public ::testing::Test {};
|
||||
TEST_F(LayoutTest, ToString) {
|
||||
EXPECT_EQ(Layout().ToString(), "invalid{}");
|
||||
EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}");
|
||||
EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(123).ToString(),
|
||||
"sparse{123}");
|
||||
EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}");
|
||||
EXPECT_EQ(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})}).ToString(),
|
||||
"{3,2,1,0:T(42,123)(4,5)}");
|
||||
@ -65,11 +63,6 @@ TEST_F(LayoutTest, StreamOut) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(LayoutTest, SparseLayoutMaxElements) {
|
||||
EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)),
|
||||
101);
|
||||
}
|
||||
|
||||
TEST_F(LayoutTest, Equality) {
|
||||
EXPECT_EQ(Layout(), Layout());
|
||||
const std::vector<int64> empty_dims;
|
||||
@ -90,12 +83,6 @@ TEST_F(LayoutTest, Equality) {
|
||||
Layout({0, 1, 2}).set_memory_space(3));
|
||||
EXPECT_NE(Layout({0, 1, 2}).set_memory_space(1),
|
||||
Layout({0, 1, 2}).set_memory_space(3));
|
||||
EXPECT_EQ(Layout().set_format(SPARSE), Layout().set_format(SPARSE));
|
||||
EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(42),
|
||||
Layout().set_format(SPARSE).set_max_sparse_elements(42));
|
||||
EXPECT_NE(Layout().set_format(SPARSE).set_max_sparse_elements(42),
|
||||
Layout().set_format(SPARSE).set_max_sparse_elements(24));
|
||||
|
||||
EXPECT_FALSE(
|
||||
Layout::Equal()(Layout({0, 1, 2}, {Tile({42, 44})}), Layout({0, 1, 2})));
|
||||
EXPECT_TRUE(Layout::Equal().IgnoreTiles()(Layout({0, 1, 2}, {Tile({42, 44})}),
|
||||
@ -117,8 +104,6 @@ TEST_F(LayoutTest, LayoutToFromProto) {
|
||||
|
||||
expect_unchanged(Layout());
|
||||
expect_unchanged(Layout({1, 3, 2, 0}));
|
||||
expect_unchanged(Layout().set_format(SPARSE));
|
||||
expect_unchanged(Layout().set_format(SPARSE).set_max_sparse_elements(123));
|
||||
expect_unchanged(Layout({0, 1}).set_element_size_in_bits(42));
|
||||
expect_unchanged(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})}));
|
||||
}
|
||||
|
@ -94,13 +94,6 @@ void SetDefaultLayoutToContainer(T* minor_to_major) {
|
||||
return layout;
|
||||
}
|
||||
|
||||
/* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) {
|
||||
Layout layout;
|
||||
layout.set_format(SPARSE);
|
||||
layout.set_max_sparse_elements(max_sparse_elements);
|
||||
return layout;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Internal helper that creates a default layout for an array of the given rank.
|
||||
@ -293,19 +286,6 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
|
||||
layout.minor_to_major().end(), std::greater<int64>());
|
||||
}
|
||||
|
||||
/* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) {
|
||||
return shape.IsArray() && shape.has_layout() && IsSparse(shape.layout());
|
||||
}
|
||||
|
||||
/* static */ bool LayoutUtil::IsSparse(const Layout& layout) {
|
||||
return layout.format() == SPARSE;
|
||||
}
|
||||
|
||||
/* static */ int64 LayoutUtil::MaxSparseElements(const Layout& layout) {
|
||||
CHECK(IsSparse(layout));
|
||||
return layout.max_sparse_elements();
|
||||
}
|
||||
|
||||
/* static */ bool LayoutUtil::HasLayout(const Shape& shape) {
|
||||
if (shape.IsTuple()) {
|
||||
// Tuple shape: all subshapes must have a layout.
|
||||
@ -461,8 +441,6 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
|
||||
for (int64 minor_to_major : layout.minor_to_major()) {
|
||||
hash_value = Hash64Combine(hash_value, hash<int64>()(minor_to_major));
|
||||
}
|
||||
hash_value = Hash64Combine(hash_value, layout.max_sparse_elements());
|
||||
|
||||
for (Tile tile : layout.tiles()) {
|
||||
for (int64 tile_dim : tile.dimensions()) {
|
||||
hash_value = Hash64Combine(hash_value, hash<int64>()(tile_dim));
|
||||
|
@ -49,10 +49,6 @@ class LayoutUtil {
|
||||
// dimensions.
|
||||
static Layout MakeDescendingLayout(int64 rank);
|
||||
|
||||
// Creates a sparse layout with the given maximum number of elements. (This is
|
||||
// a convenience function for protobuf construction.)
|
||||
static Layout MakeSparseLayout(int64 max_sparse_elements);
|
||||
|
||||
// Returns default layout for the given shape.
|
||||
static Layout GetDefaultLayoutForShape(const Shape& shape);
|
||||
|
||||
@ -109,17 +105,6 @@ class LayoutUtil {
|
||||
// more minor, and so on until dimension N-1 which is the minor.
|
||||
static bool IsMonotonicWithDim0Major(const Layout& layout);
|
||||
|
||||
// Returns whether the given Shape is an array (i.e. not a tuple) and has a
|
||||
// sparse format layout.
|
||||
static bool IsSparseArray(const Shape& shape);
|
||||
|
||||
// Returns whether the given Layout has a sparse format.
|
||||
static bool IsSparse(const Layout& layout);
|
||||
|
||||
// Returns the maximum number of elements that can be stored in a sparse
|
||||
// layout.
|
||||
static int64 MaxSparseElements(const Layout& layout);
|
||||
|
||||
// Returns whether the given shape has a layout. For tuple shapes, true is
|
||||
// returned only if all elements have layouts.
|
||||
static bool HasLayout(const Shape& shape);
|
||||
|
@ -33,14 +33,6 @@ class LayoutUtilTest : public ::testing::Test {
|
||||
*shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
|
||||
return shape;
|
||||
}
|
||||
|
||||
Shape MakeShapeWithSparseLayout(PrimitiveType element_type,
|
||||
absl::Span<const int64> dimensions,
|
||||
int64 max_sparse_elements) {
|
||||
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
|
||||
*shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
|
||||
return shape;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(LayoutUtilTest, TupleLayoutComparison) {
|
||||
@ -92,29 +84,6 @@ TEST_F(LayoutUtilTest, CopyLayoutArray) {
|
||||
EXPECT_FALSE(dst.has_layout());
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, CopyLayoutSparse) {
|
||||
Shape src = MakeShapeWithSparseLayout(F32, {2, 3}, 2);
|
||||
Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0});
|
||||
|
||||
EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
|
||||
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
|
||||
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
|
||||
|
||||
// Should work if destination has no layout.
|
||||
dst.clear_layout();
|
||||
EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
|
||||
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
|
||||
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
|
||||
|
||||
// If source is cleared, then destination should be cleared.
|
||||
src.clear_layout();
|
||||
EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
|
||||
EXPECT_TRUE(dst.has_layout());
|
||||
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
|
||||
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
|
||||
EXPECT_FALSE(dst.has_layout());
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, CopyLayoutTuple) {
|
||||
Shape src = ShapeUtil::MakeTupleShape(
|
||||
{MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
|
||||
@ -134,25 +103,6 @@ TEST_F(LayoutUtilTest, CopyLayoutTuple) {
|
||||
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, CopyLayoutTupleSparse) {
|
||||
Shape src = ShapeUtil::MakeTupleShape(
|
||||
{MakeShapeWithSparseLayout(F32, {2, 3}, 4),
|
||||
MakeShapeWithSparseLayout(F32, {42, 123}, 4),
|
||||
ShapeUtil::MakeTupleShape(
|
||||
{MakeShapeWithLayout(F32, {}, {}),
|
||||
MakeShapeWithSparseLayout(F32, {1, 2, 3}, 6)})});
|
||||
Shape dst = ShapeUtil::MakeTupleShape(
|
||||
{MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
|
||||
MakeShapeWithLayout(F32, {42, 123}, {1, 0}),
|
||||
ShapeUtil::MakeTupleShape(
|
||||
{MakeShapeWithLayout(F32, {}, {}),
|
||||
MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})});
|
||||
|
||||
EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
|
||||
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
|
||||
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) {
|
||||
Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1});
|
||||
Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0});
|
||||
@ -160,13 +110,6 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) {
|
||||
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleSameRank) {
|
||||
Shape src = MakeShapeWithSparseLayout(F32, {123, 42, 7}, 6);
|
||||
Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0});
|
||||
ASSERT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
|
||||
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) {
|
||||
Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1});
|
||||
Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0});
|
||||
@ -176,15 +119,6 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) {
|
||||
::testing::ContainsRegex("cannot copy layout from shape"));
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleDifferentRank) {
|
||||
Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1});
|
||||
Shape dst = MakeShapeWithSparseLayout(F32, {2, 3}, 4);
|
||||
auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst);
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_THAT(status.error_message(),
|
||||
::testing::ContainsRegex("cannot copy layout from shape"));
|
||||
}
|
||||
|
||||
TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleTuple) {
|
||||
Shape src =
|
||||
ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
|
||||
|
@ -80,7 +80,7 @@ bool LiteralProtoHasValues(const LiteralProto& proto) {
|
||||
proto.c64s_size() || proto.c128s_size() ||
|
||||
proto.tuple_literals_size() || !proto.f16s().empty() ||
|
||||
!proto.bf16s().empty() || !proto.u16s().empty() ||
|
||||
!proto.s16s().empty() || proto.sparse_indices_size();
|
||||
!proto.s16s().empty();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -135,21 +135,8 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
|
||||
// Literals can be used as DMA targets, which can require alignment. We
|
||||
// force a 16-byte minimum alignment.
|
||||
constexpr int kMinimumAlignment = 16;
|
||||
if (LayoutUtil::IsSparseArray(shape)) {
|
||||
// For sparse arrays, the buffer must be of the size of the maximum
|
||||
// number of sparse elements possible.
|
||||
const int64 max_sparse_elements =
|
||||
LayoutUtil::MaxSparseElements(shape.layout());
|
||||
piece->set_buffer(static_cast<char*>(tensorflow::port::AlignedMalloc(
|
||||
max_sparse_elements *
|
||||
ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()),
|
||||
kMinimumAlignment)));
|
||||
piece->set_sparse_indices(
|
||||
new SparseIndexArray(max_sparse_elements, shape.rank()));
|
||||
} else {
|
||||
piece->set_buffer(static_cast<char*>(tensorflow::port::AlignedMalloc(
|
||||
piece->size_bytes(), kMinimumAlignment)));
|
||||
}
|
||||
piece->set_buffer(static_cast<char*>(tensorflow::port::AlignedMalloc(
|
||||
piece->size_bytes(), kMinimumAlignment)));
|
||||
}
|
||||
} else {
|
||||
// If the shape is neither an array nor tuple, then it must be
|
||||
@ -181,7 +168,6 @@ void Literal::DeallocateBuffers() {
|
||||
[&](const ShapeIndex& index, Piece* piece) {
|
||||
if (piece->buffer() != nullptr) {
|
||||
tensorflow::port::AlignedFree(piece->buffer());
|
||||
delete piece->sparse_indices();
|
||||
}
|
||||
});
|
||||
}
|
||||
@ -211,16 +197,6 @@ Literal LiteralBase::CreateFromShape(const Shape& shape) {
|
||||
return literal;
|
||||
}
|
||||
|
||||
const SparseIndexArray* LiteralBase::sparse_indices(
|
||||
const ShapeIndex& shape_index) const {
|
||||
return piece(shape_index).sparse_indices();
|
||||
}
|
||||
|
||||
SparseIndexArray* MutableLiteralBase::sparse_indices(
|
||||
const ShapeIndex& shape_index) {
|
||||
return piece(shape_index).sparse_indices();
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
Status MutableLiteralBase::CopySliceFromInternal(
|
||||
const LiteralBase& src_literal, absl::Span<const int64> src_base,
|
||||
@ -373,12 +349,9 @@ std::vector<Literal> Literal::DecomposeTuple() {
|
||||
}
|
||||
Piece& src_piece = piece(src_index);
|
||||
|
||||
// Move the respective buffer and sparse indices over to the element
|
||||
// Literal.
|
||||
// Move the respective buffer over to the element Literal.
|
||||
dest_piece->set_buffer(src_piece.buffer());
|
||||
src_piece.set_buffer(nullptr);
|
||||
dest_piece->set_sparse_indices(src_piece.sparse_indices());
|
||||
src_piece.set_sparse_indices(nullptr);
|
||||
});
|
||||
}
|
||||
// Set this literal to be nil-shaped.
|
||||
@ -512,8 +485,6 @@ Status Literal::MoveFrom(Literal&& src_literal,
|
||||
Piece& dest_piece = piece(dest_index);
|
||||
tensorflow::port::AlignedFree(dest_piece.buffer());
|
||||
dest_piece.set_buffer(src_piece.buffer());
|
||||
delete dest_piece.sparse_indices();
|
||||
dest_piece.set_sparse_indices(src_piece.sparse_indices());
|
||||
});
|
||||
|
||||
src_literal.shape_ = absl::make_unique<Shape>(ShapeUtil::MakeNil());
|
||||
@ -854,66 +825,6 @@ string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
|
||||
}
|
||||
}
|
||||
|
||||
string LiteralBase::GetSparseElementAsString(
|
||||
int64 sparse_element_number, const ShapeIndex& shape_index) const {
|
||||
const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
|
||||
CHECK(LayoutUtil::IsSparseArray(subshape));
|
||||
switch (subshape.element_type()) {
|
||||
case PRED:
|
||||
return GetSparseElement<bool>(sparse_element_number, shape_index)
|
||||
? "true"
|
||||
: "false";
|
||||
case S8:
|
||||
return StrCat(GetSparseElement<int8>(sparse_element_number, shape_index));
|
||||
case S16:
|
||||
return StrCat(
|
||||
GetSparseElement<int16>(sparse_element_number, shape_index));
|
||||
case S32:
|
||||
return StrCat(
|
||||
GetSparseElement<int32>(sparse_element_number, shape_index));
|
||||
case S64:
|
||||
return StrCat(
|
||||
GetSparseElement<int64>(sparse_element_number, shape_index));
|
||||
case U8:
|
||||
return StrCat(
|
||||
GetSparseElement<uint8>(sparse_element_number, shape_index));
|
||||
case U16:
|
||||
return StrCat(
|
||||
GetSparseElement<uint16>(sparse_element_number, shape_index));
|
||||
case U32:
|
||||
return StrCat(
|
||||
GetSparseElement<uint32>(sparse_element_number, shape_index));
|
||||
case U64:
|
||||
return StrCat(
|
||||
GetSparseElement<uint64>(sparse_element_number, shape_index));
|
||||
case F16:
|
||||
return StrCat(static_cast<float>(
|
||||
GetSparseElement<half>(sparse_element_number, shape_index)));
|
||||
case F32:
|
||||
return StrCat(
|
||||
GetSparseElement<float>(sparse_element_number, shape_index));
|
||||
case BF16:
|
||||
return StrCat(static_cast<float>(
|
||||
GetSparseElement<bfloat16>(sparse_element_number, shape_index)));
|
||||
case F64:
|
||||
return StrCat(
|
||||
GetSparseElement<double>(sparse_element_number, shape_index));
|
||||
case C64: {
|
||||
complex64 c =
|
||||
GetSparseElement<complex64>(sparse_element_number, shape_index);
|
||||
return StrCat("(", c.real(), ", ", c.imag(), ")");
|
||||
}
|
||||
case C128: {
|
||||
complex128 c =
|
||||
GetSparseElement<complex128>(sparse_element_number, shape_index);
|
||||
return StrCat("(", c.real(), ", ", c.imag(), ")");
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "Invalid element type for sparse arrays: "
|
||||
<< PrimitiveType_Name(subshape.element_type());
|
||||
}
|
||||
}
|
||||
|
||||
absl::optional<int64> LiteralBase::GetIntegralAsS64(
|
||||
absl::Span<const int64> multi_index) const {
|
||||
CHECK(LayoutUtil::IsDenseArray(shape()));
|
||||
@ -1047,81 +958,6 @@ Status MutableLiteralBase::SetFromDouble(absl::Span<const int64> multi_index,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
absl::Span<const int64> LiteralBase::GetSparseIndex(
|
||||
int64 sparse_element_number, const ShapeIndex& shape_index) const {
|
||||
const Piece& p = piece(shape_index);
|
||||
CHECK_GE(sparse_element_number, 0);
|
||||
CHECK_LT(sparse_element_number, p.sparse_indices()->index_count());
|
||||
return p.sparse_indices()->At(sparse_element_number);
|
||||
}
|
||||
|
||||
void MutableLiteralBase::SortSparseElements(const ShapeIndex& shape_index) {
|
||||
piece(shape_index).SortSparseElements();
|
||||
}
|
||||
|
||||
void LiteralBase::Piece::SortSparseElements() {
|
||||
switch (subshape().element_type()) {
|
||||
case PRED:
|
||||
SortSparseElementsInternal<bool>();
|
||||
break;
|
||||
case S8:
|
||||
SortSparseElementsInternal<int8>();
|
||||
break;
|
||||
case U8:
|
||||
SortSparseElementsInternal<uint8>();
|
||||
break;
|
||||
case S16:
|
||||
SortSparseElementsInternal<int16>();
|
||||
break;
|
||||
case U16:
|
||||
SortSparseElementsInternal<uint16>();
|
||||
break;
|
||||
case S32:
|
||||
SortSparseElementsInternal<int32>();
|
||||
break;
|
||||
case U32:
|
||||
SortSparseElementsInternal<uint32>();
|
||||
break;
|
||||
case S64:
|
||||
SortSparseElementsInternal<int64>();
|
||||
break;
|
||||
case U64:
|
||||
SortSparseElementsInternal<uint64>();
|
||||
break;
|
||||
case F32:
|
||||
SortSparseElementsInternal<float>();
|
||||
break;
|
||||
case F64:
|
||||
SortSparseElementsInternal<double>();
|
||||
break;
|
||||
case C64:
|
||||
SortSparseElementsInternal<complex64>();
|
||||
break;
|
||||
case C128:
|
||||
SortSparseElementsInternal<complex128>();
|
||||
break;
|
||||
case F16:
|
||||
SortSparseElementsInternal<half>();
|
||||
break;
|
||||
case BF16:
|
||||
SortSparseElementsInternal<bfloat16>();
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Element type not valid for sparse array: "
|
||||
<< PrimitiveType_Name(subshape().element_type());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
void LiteralBase::Piece::SortSparseElementsInternal() {
|
||||
CHECK(LayoutUtil::IsSparseArray(subshape()));
|
||||
int64 num_elements = sparse_indices()->index_count();
|
||||
auto values = data<NativeT>();
|
||||
CHECK_LE(num_elements, values.size());
|
||||
sparse_indices()->SortWithValues(
|
||||
absl::Span<NativeT>(values.data(), num_elements));
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
string ShapeToString(bool print_layout, const Shape& shape) {
|
||||
@ -1151,32 +987,6 @@ void TupleToStringHelper(const LiteralBase& literal,
|
||||
pieces->push_back("\n)");
|
||||
}
|
||||
|
||||
void SparseArrayToStringHelper(const LiteralBase& literal,
|
||||
const Shape& subshape, bool print_shape,
|
||||
bool print_layout, std::vector<string>* pieces) {
|
||||
if (print_shape) {
|
||||
pieces->push_back(ShapeToString(print_layout, subshape));
|
||||
}
|
||||
pieces->push_back("{");
|
||||
int64 rank = subshape.rank();
|
||||
int64 num_elements = literal.sparse_element_count();
|
||||
for (int64 i = 0; i < num_elements; ++i) {
|
||||
if (i > 0) {
|
||||
pieces->push_back(", ");
|
||||
}
|
||||
if (rank == 1) {
|
||||
pieces->push_back(StrCat(literal.GetSparseIndex(i)[0]));
|
||||
pieces->push_back(": ");
|
||||
} else {
|
||||
pieces->push_back("[");
|
||||
pieces->push_back(absl::StrJoin(literal.GetSparseIndex(i), ", "));
|
||||
pieces->push_back("]: ");
|
||||
}
|
||||
pieces->push_back(literal.GetSparseElementAsString(i));
|
||||
}
|
||||
pieces->push_back("}");
|
||||
}
|
||||
|
||||
void DenseArrayToStringHelper(const LiteralBase& literal,
|
||||
const ShapeIndex& shape_index, bool print_shape,
|
||||
bool print_layout, std::vector<string>* pieces) {
|
||||
@ -1261,9 +1071,6 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
|
||||
pieces);
|
||||
} else if (subshape.IsToken()) {
|
||||
pieces->push_back("token");
|
||||
} else if (LayoutUtil::IsSparseArray(subshape)) {
|
||||
SparseArrayToStringHelper(literal, subshape, print_shape, print_layout,
|
||||
pieces);
|
||||
} else {
|
||||
CHECK(LayoutUtil::IsDenseArray(subshape));
|
||||
DenseArrayToStringHelper(literal, shape_index, print_shape, print_layout,
|
||||
@ -1273,11 +1080,6 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
|
||||
|
||||
} // namespace
|
||||
|
||||
int64 LiteralBase::sparse_element_count() const {
|
||||
CHECK(LayoutUtil::IsSparseArray(shape()));
|
||||
return sparse_indices()->index_count();
|
||||
}
|
||||
|
||||
string LiteralBase::ToString() const {
|
||||
std::vector<string> pieces;
|
||||
CHECK(LayoutUtil::HasLayout(this->shape()));
|
||||
@ -2053,22 +1855,6 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
|
||||
TF_RET_CHECK(LayoutUtil::HasLayout(shape));
|
||||
TF_RET_CHECK(ShapeUtil::Equal(shape, subshape()));
|
||||
|
||||
if (LayoutUtil::IsSparseArray(subshape())) {
|
||||
// Compute the number of elements (indices) in the sparse shape and reserve
|
||||
// the necessary space in spare_indices.
|
||||
TF_RET_CHECK(subshape().rank() != 0) << "Scalar shapes cannot be sparse";
|
||||
TF_RET_CHECK(proto.sparse_indices_size() % subshape().rank() == 0)
|
||||
<< "Unexpected number of indices in proto ("
|
||||
<< proto.sparse_indices_size() << ") for shape of rank "
|
||||
<< subshape().rank();
|
||||
const int64 index_count = proto.sparse_indices_size() / subshape().rank();
|
||||
sparse_indices()->Resize(index_count);
|
||||
|
||||
// Copy the indices from the proto into the SparseIndexArray object.
|
||||
TF_RETURN_IF_ERROR(CopyFromRepeatedField(sparse_indices()->mutable_data(),
|
||||
proto.sparse_indices()));
|
||||
}
|
||||
|
||||
switch (subshape().element_type()) {
|
||||
case PRED:
|
||||
TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
|
||||
@ -2175,11 +1961,6 @@ LiteralProto LiteralBase::ToProto() const {
|
||||
piece.WriteToProto(proto_piece);
|
||||
});
|
||||
|
||||
if (LayoutUtil::IsSparseArray(shape())) {
|
||||
CopyToRepeatedField(proto.mutable_sparse_indices(),
|
||||
sparse_indices()->data());
|
||||
}
|
||||
|
||||
return proto;
|
||||
}
|
||||
|
||||
@ -2295,12 +2076,6 @@ MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr,
|
||||
|
||||
MutableBorrowingLiteral::~MutableBorrowingLiteral() {
|
||||
if (root_piece_ != nullptr) {
|
||||
root_piece_->ForEachMutableSubpiece(
|
||||
[&](const ShapeIndex& index, Piece* piece) {
|
||||
if (piece->buffer() != nullptr) {
|
||||
delete piece->sparse_indices();
|
||||
}
|
||||
});
|
||||
delete root_piece_;
|
||||
}
|
||||
}
|
||||
|
@ -35,7 +35,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/sparse_index_array.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
@ -77,11 +76,6 @@ class LiteralBase {
|
||||
template <typename NativeT>
|
||||
absl::Span<const NativeT> data(const ShapeIndex& shape_index = {}) const;
|
||||
|
||||
// Returns a const pointer to the sparse index array. Returns nullptr if the
|
||||
// literal is not a sparse array.
|
||||
const SparseIndexArray* sparse_indices(
|
||||
const ShapeIndex& shape_index = {}) const;
|
||||
|
||||
// Returns a const pointer to (or size of) the underlying buffer holding the
|
||||
// array at the given shape index. CHECKs if the subshape of the literal at
|
||||
// the given ShapeIndex is not array.
|
||||
@ -126,10 +120,6 @@ class LiteralBase {
|
||||
// into text.
|
||||
string GetAsString(absl::Span<const int64> multi_index,
|
||||
const ShapeIndex& shape_index = {}) const;
|
||||
// As GetSparseElement(), but determines the correct type and converts the
|
||||
// value into text.
|
||||
string GetSparseElementAsString(int64 sparse_element_number,
|
||||
const ShapeIndex& shape_index = {}) const;
|
||||
|
||||
// Return whether the value at the specified index is equal to the provided
|
||||
// generic `value` (T must be an arithmetic type).
|
||||
@ -172,21 +162,6 @@ class LiteralBase {
|
||||
absl::optional<complex128> GetAsComplex128(
|
||||
absl::Span<const int64> multi_index) const;
|
||||
|
||||
// Returns the multi-index of the element in a sparse literal at the given
|
||||
// sparse element number. The sparse element number is the position with in
|
||||
// the sparse array's list of (index, value) pairs, and is checked against the
|
||||
// total number of (index, value) pairs in the sparse array.
|
||||
absl::Span<const int64> GetSparseIndex(
|
||||
int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;
|
||||
|
||||
// Returns the value of the element in a sparse literal at the given sparse
|
||||
// element number. The sparse element number is the position with in the
|
||||
// sparse array's list of (index, value) pairs, and is checked against the
|
||||
// total number of (index, value) pairs in the sparse array.
|
||||
template <typename NativeT>
|
||||
NativeT GetSparseElement(int64 sparse_element_number,
|
||||
const ShapeIndex& shape_index = {}) const;
|
||||
|
||||
// Invokes the "per cell" callback for each element in the provided
|
||||
// literal with the element's indices and a string representation of
|
||||
// the element's value.
|
||||
@ -259,13 +234,7 @@ class LiteralBase {
|
||||
return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
|
||||
}
|
||||
|
||||
// Returns the count of the elements in the sparse array at the given shape
|
||||
// index in this literal, which will be no larger than
|
||||
// LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()).
|
||||
int64 sparse_element_count() const;
|
||||
|
||||
// Compute a hash for this literal. This literal must not be a sparse tensor
|
||||
// or a tuple containing a sparse tensor.
|
||||
// Compute a hash for this literal.
|
||||
size_t Hash() const;
|
||||
|
||||
// Converts this literal to the given shape. Returns an error is the
|
||||
@ -385,14 +354,6 @@ class LiteralBase {
|
||||
char* buffer() const { return buffer_; }
|
||||
void set_buffer(char* buffer) { buffer_ = buffer; }
|
||||
|
||||
// The array of multi-indices that provide the locations of non-zero
|
||||
// elements in a sparse array. Only used if
|
||||
// LayoutUtil::IsSparseArray(shape()) is true.
|
||||
SparseIndexArray* sparse_indices() const { return sparse_indices_; }
|
||||
void set_sparse_indices(SparseIndexArray* sparse_indices) {
|
||||
sparse_indices_ = sparse_indices;
|
||||
}
|
||||
|
||||
// Gets or sets the subshape of this piece. This reference points to a
|
||||
// subshape within the shape in the containing Literal (Literal::shape_).
|
||||
const Shape& subshape() const { return *subshape_; }
|
||||
@ -402,13 +363,7 @@ class LiteralBase {
|
||||
int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
|
||||
|
||||
// Returns the number of elements in this piece's array.
|
||||
int64 element_count() const {
|
||||
// If this is a sparse array, use the number of elements represented by
|
||||
// the indices in the associated SparseIndexArray.
|
||||
return LayoutUtil::IsSparseArray(subshape())
|
||||
? sparse_indices()->index_count()
|
||||
: ShapeUtil::ElementsIn(subshape());
|
||||
}
|
||||
int64 element_count() const { return ShapeUtil::ElementsIn(subshape()); }
|
||||
|
||||
// Returns the child piece at 'index' of this piece.
|
||||
Piece& child(int64 index) { return children_[index]; }
|
||||
@ -489,9 +444,6 @@ class LiteralBase {
|
||||
// piece must be equal (not just compatible) to the shape of the proto.
|
||||
Status CopyFromProto(const LiteralProto& proto);
|
||||
|
||||
// Sorts the elements in a sparse array.
|
||||
void SortSparseElements();
|
||||
|
||||
private:
|
||||
// Helpers for traversing the piece via ForEachSubpiece rooted at 'index'.
|
||||
// The first non-OK (or non-true) value is returned by the function.
|
||||
@ -541,17 +493,9 @@ class LiteralBase {
|
||||
bool EqualElementsInternal(const Piece& other,
|
||||
std::vector<int64>* multi_index) const;
|
||||
|
||||
// Helper for SortSparseElements that has the element type as a template
|
||||
// parameter.
|
||||
template <typename NativeT>
|
||||
void SortSparseElementsInternal();
|
||||
|
||||
// For array-shaped pieces, this is the buffer holding the literal data.
|
||||
char* buffer_ = nullptr;
|
||||
|
||||
// For sparse arrays, this is the array of indices.
|
||||
SparseIndexArray* sparse_indices_ = nullptr;
|
||||
|
||||
// The shape of piece. This points into the shape of the containing Literal
|
||||
// (Literal::shape_).
|
||||
const Shape* subshape_ = nullptr;
|
||||
@ -598,10 +542,6 @@ class MutableLiteralBase : public LiteralBase {
|
||||
// Unhide const method from parent class.
|
||||
using LiteralBase::data;
|
||||
|
||||
// Returns a pointer to the sparse index array. Returns nullptr if the literal
|
||||
// is not a sparse array.
|
||||
SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
|
||||
|
||||
// TODO(b/67651157): Remove this accessor. Literal users should not be able to
|
||||
// mutate the shape as this can produce malformed Literals.
|
||||
Shape* mutable_shape_do_not_use() { return shape_.get(); }
|
||||
@ -613,16 +553,6 @@ class MutableLiteralBase : public LiteralBase {
|
||||
// Unhide const method from parent class.
|
||||
using LiteralBase::untyped_data;
|
||||
|
||||
// Populates a literal with a sparse layout with the given indices and values.
|
||||
// Each index in the indices array is CHECKed against the dimensions in the
|
||||
// literal's shape. If sort is true, then the indices and values will be
|
||||
// sorted. If sort is false, then the indices and values are assumed to
|
||||
// already be in sorted order. See CreateSparse for an example of how data
|
||||
// are populated.
|
||||
template <typename NativeT>
|
||||
void PopulateSparse(SparseIndexArray indices,
|
||||
absl::Span<const NativeT> values, bool sort = true);
|
||||
|
||||
// Copy values from 'src_literal' rooted at 'src_shape_index' into this
|
||||
// literal rooted at 'dest_shape_index'. The subshape of this literal rooted
|
||||
// at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
|
||||
@ -661,16 +591,6 @@ class MutableLiteralBase : public LiteralBase {
|
||||
template <typename NativeT>
|
||||
void Set(absl::Span<const int64> multi_index, NativeT value);
|
||||
|
||||
// Appends the given element to the literal. If the elements are not appended
|
||||
// in sorted order, then SortSparseElements should be called before calling
|
||||
// other methods. This literal must have a sparse layout.
|
||||
template <typename NativeT>
|
||||
void AppendSparseElement(absl::Span<const int64> multi_index, NativeT value,
|
||||
const ShapeIndex& shape_index = {});
|
||||
|
||||
// Sorts the elements in a sparse array.
|
||||
void SortSparseElements(const ShapeIndex& shape_index = {});
|
||||
|
||||
// As Set(), but truncates `value` to the literal element type before storing.
|
||||
// This literal must be an array.
|
||||
Status SetIntegralAsS64(absl::Span<const int64> multi_index, int64 value);
|
||||
@ -988,34 +908,6 @@ NativeT LiteralBase::GetFirstElement() const {
|
||||
return data<NativeT>().at(0);
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
NativeT LiteralBase::GetSparseElement(int64 sparse_element_number,
|
||||
const ShapeIndex& shape_index) const {
|
||||
CHECK(
|
||||
LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index)));
|
||||
return data<NativeT>(shape_index)[sparse_element_number];
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
void MutableLiteralBase::AppendSparseElement(
|
||||
absl::Span<const int64> multi_index, NativeT value,
|
||||
const ShapeIndex& shape_index) {
|
||||
Piece& p = piece(shape_index);
|
||||
const Shape& subshape = p.subshape();
|
||||
CHECK(LayoutUtil::IsSparseArray(subshape));
|
||||
int64 rank = subshape.rank();
|
||||
CHECK_EQ(multi_index.size(), rank);
|
||||
for (int64 i = 0; i < rank; ++i) {
|
||||
CHECK_GE(multi_index[i], 0);
|
||||
CHECK_LT(multi_index[i], subshape.dimensions(i));
|
||||
}
|
||||
int64 last_element = p.sparse_indices()->index_count();
|
||||
CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout()));
|
||||
p.sparse_indices()->Append(multi_index);
|
||||
CHECK_LT(last_element, p.data<NativeT>().size());
|
||||
p.data<NativeT>()[last_element] = value;
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
void LiteralBase::EachCell(
|
||||
std::function<void(absl::Span<const int64> indices, NativeT value)>
|
||||
@ -1094,31 +986,6 @@ void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
|
||||
PopulateFromArray(values);
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
void MutableLiteralBase::PopulateSparse(SparseIndexArray indices,
|
||||
absl::Span<const NativeT> values,
|
||||
bool sort) {
|
||||
CHECK(LayoutUtil::IsSparseArray(shape()));
|
||||
int rank = shape().rank();
|
||||
CHECK_EQ(indices.rank(), rank);
|
||||
int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout());
|
||||
CHECK_LE(indices.max_indices(), max_elements);
|
||||
int64 num_elements = values.size();
|
||||
CHECK_LE(num_elements, max_elements);
|
||||
CHECK_EQ(num_elements, indices.index_count());
|
||||
auto root_data = root_piece().data<NativeT>();
|
||||
// Piece::data() returns a Span of size equal to the number of indices
|
||||
// in the SparseIndexArray. So there is no need to adjust the size of the data
|
||||
// here. It is enough to just copy the incoming values into the data buffer.
|
||||
std::copy(values.begin(), values.end(), root_data.begin());
|
||||
*this->root_piece().sparse_indices() = std::move(indices);
|
||||
if (sort) {
|
||||
auto root_data = this->root_piece().data<NativeT>();
|
||||
this->root_piece().sparse_indices()->SortWithValues(root_data);
|
||||
}
|
||||
DCHECK(this->root_piece().sparse_indices()->Validate(shape()));
|
||||
}
|
||||
|
||||
template <typename NativeT, typename FnType>
|
||||
Status MutableLiteralBase::PopulateInternal(const FnType& generator,
|
||||
bool parallel) {
|
||||
|
@ -252,42 +252,6 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
|
||||
EXPECT_EQ(expected, result);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, CreateSparse) {
|
||||
std::vector<int64> dimensions = {8, 8, 8};
|
||||
Array2D<int64> indices = {
|
||||
{3, 4, 5},
|
||||
{1, 2, 3},
|
||||
{2, 3, 4},
|
||||
{3, 5, 6},
|
||||
};
|
||||
std::vector<int64> values = {7, 8, 9, 10};
|
||||
auto literal = LiteralUtil::CreateSparse<int64>(
|
||||
dimensions, SparseIndexArray(indices.n1() + 3, indices), values);
|
||||
|
||||
Array2D<int64> expected_indices = {
|
||||
{1, 2, 3},
|
||||
{2, 3, 4},
|
||||
{3, 4, 5},
|
||||
{3, 5, 6},
|
||||
};
|
||||
std::vector<int64> expected_values = {8, 9, 7, 10};
|
||||
|
||||
EXPECT_EQ(literal.sparse_indices()->data(),
|
||||
absl::Span<const int64>(expected_indices.data(),
|
||||
expected_indices.num_elements()));
|
||||
EXPECT_EQ(literal.data<int64>(), absl::Span<const int64>(expected_values));
|
||||
|
||||
// Serialize then deserialize and verify the resulting literal.
|
||||
TF_ASSERT_OK_AND_ASSIGN(Literal literal_from_proto,
|
||||
Literal::CreateFromProto(literal.ToProto()));
|
||||
|
||||
EXPECT_EQ(literal_from_proto.sparse_indices()->data(),
|
||||
absl::Span<const int64>(expected_indices.data(),
|
||||
expected_indices.num_elements()));
|
||||
EXPECT_EQ(literal_from_proto.data<int64>(),
|
||||
absl::Span<const int64>(expected_values));
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
|
||||
// clang-format off
|
||||
auto literal = LiteralUtil::CreateR4Projected<float>({
|
||||
@ -1978,43 +1942,6 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
|
||||
EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements"));
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, SortSparseElements) {
|
||||
auto literal = LiteralUtil::CreateSparse<float>({10, 10, 10},
|
||||
SparseIndexArray(10, 3), {});
|
||||
literal.AppendSparseElement<float>({2, 3, 4}, 2.0);
|
||||
literal.AppendSparseElement<float>({3, 4, 5}, 3.0);
|
||||
literal.AppendSparseElement<float>({1, 2, 3}, 1.0);
|
||||
literal.SortSparseElements();
|
||||
EXPECT_EQ(literal.ToString(),
|
||||
"f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}");
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, GetSparseElementAsString) {
|
||||
std::vector<int64> dimensions = {10, 10, 10};
|
||||
SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}});
|
||||
|
||||
EXPECT_EQ(
|
||||
LiteralUtil::CreateSparse<bool>(dimensions, indices, {true, false, true})
|
||||
.GetSparseElementAsString(1),
|
||||
"false");
|
||||
EXPECT_EQ(LiteralUtil::CreateSparse<int64>(dimensions, indices, {1, 2, 3})
|
||||
.GetSparseElementAsString(1),
|
||||
absl::StrCat(int64{2}));
|
||||
EXPECT_EQ(
|
||||
LiteralUtil::CreateSparse<double>(dimensions, indices, {1.0, 2.0, 3.0})
|
||||
.GetSparseElementAsString(1),
|
||||
absl::StrCat(double{2.0}));
|
||||
EXPECT_EQ(LiteralUtil::CreateSparse<half>(dimensions, indices,
|
||||
{half{1.0}, half{2.0}, half{3.0}})
|
||||
.GetSparseElementAsString(1),
|
||||
absl::StrCat(static_cast<float>(half{2.0})));
|
||||
EXPECT_EQ(LiteralUtil::CreateSparse<complex64>(
|
||||
dimensions, indices,
|
||||
std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})
|
||||
.GetSparseElementAsString(1),
|
||||
absl::StrCat("(", float{3.0}, ", ", float{4.0}, ")"));
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) {
|
||||
Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
|
@ -38,7 +38,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/sparse_index_array.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
@ -102,46 +101,6 @@ class LiteralUtil {
|
||||
values,
|
||||
const Layout& layout);
|
||||
|
||||
// Creates a literal with a sparse layout and the given indices and values.
|
||||
// The shape is initialized from the given dimensions. The minor dimension of
|
||||
// the indices array must equal the rank of the shape (i.e. size of the
|
||||
// dimensions array). The major dimension of the indices array must equal the
|
||||
// number of elements in the values array. The maximum number of elements in
|
||||
// the array is taken from the max_indices() value of the index array.
|
||||
//
|
||||
// XLA assumes that sparse literals are in sorted order for all operations. If
|
||||
// the `sort` argument is true, then the indices and values will be sorted
|
||||
// while copying them into the literal. If you have ensured that the indices
|
||||
// and values are already sorted, then you may set the `sort` argument to
|
||||
// false to skip the sorting step.
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// CreateSparse(
|
||||
// {12, 12, 12},
|
||||
// SparseIndexArray(10, 3,
|
||||
// Array2D{
|
||||
// {0, 1, 2},
|
||||
// {3, 4, 5},
|
||||
// {6, 7, 8},
|
||||
// {9, 10, 11},
|
||||
// }),
|
||||
// {1.0, 2.0 3.0, 4.0})
|
||||
//
|
||||
// This creates an array with shape F64[12,12,12]sparse{10}, that has the
|
||||
// following non-zero values:
|
||||
//
|
||||
// [0, 1, 2]: 1.0
|
||||
// [3, 4, 5]: 2.0
|
||||
// [6, 7, 8]: 3.0
|
||||
// [9, 10, 11]: 4.0
|
||||
//
|
||||
template <typename NativeT>
|
||||
static Literal CreateSparse(absl::Span<const int64> dimensions,
|
||||
SparseIndexArray indices,
|
||||
absl::Span<const NativeT> values,
|
||||
bool sort = true);
|
||||
|
||||
// Creates a scalar literal value zero of the given primitive type.
|
||||
static Literal Zero(PrimitiveType primitive_type);
|
||||
// Creates a scalar literal value one of the given primitive type.
|
||||
@ -417,21 +376,6 @@ template <typename NativeT>
|
||||
return CreateR4FromArray4DWithLayout(tmp, layout);
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ Literal LiteralUtil::CreateSparse(
|
||||
absl::Span<const int64> dimensions, SparseIndexArray indices,
|
||||
absl::Span<const NativeT> values, bool sort) {
|
||||
int64 num_elements = values.size();
|
||||
int64 rank = dimensions.size();
|
||||
CHECK_EQ(num_elements, indices.index_count());
|
||||
CHECK_EQ(rank, indices.rank());
|
||||
Literal literal(ShapeUtil::MakeShapeWithSparseLayout(
|
||||
primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
|
||||
indices.max_indices()));
|
||||
literal.PopulateSparse(indices, values, sort);
|
||||
return literal;
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ Literal LiteralUtil::CreateR4(
|
||||
std::initializer_list<std::initializer_list<
|
||||
|
@ -77,7 +77,6 @@ cc_library(
|
||||
":buffer_info_util",
|
||||
":conv_canonicalization",
|
||||
":cpu_executable",
|
||||
":cpu_hlo_support_checker",
|
||||
":cpu_instruction_fusion",
|
||||
":cpu_layout_assignment",
|
||||
":cpu_options",
|
||||
@ -960,32 +959,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cpu_hlo_support_checker",
|
||||
srcs = ["cpu_hlo_support_checker.cc"],
|
||||
hdrs = ["cpu_hlo_support_checker.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "cpu_hlo_support_checker_test",
|
||||
srcs = ["cpu_hlo_support_checker_test.cc"],
|
||||
deps = [
|
||||
":cpu_hlo_support_checker",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "cpu_eigen_tensor_alignment_test",
|
||||
size = "small",
|
||||
|
@ -60,7 +60,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
|
||||
@ -248,7 +247,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
|
||||
pipeline.AddPass<ZeroSizedHloElimination>();
|
||||
|
||||
pipeline.AddPass<DynamicIndexSplitter>();
|
||||
pipeline.AddPass<CpuHloSupportChecker>();
|
||||
|
||||
pipeline.AddPass<ConditionalToSelect>();
|
||||
pipeline.AddPass<MapInliner>();
|
||||
|
@ -1,46 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
StatusOr<bool> CpuHloSupportChecker::Run(HloModule* module) {
|
||||
for (auto* computation : module->computations()) {
|
||||
for (const auto& instruction : computation->instructions()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape()));
|
||||
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
|
||||
instruction->shape(),
|
||||
[&instruction](const Shape& subshape, const ShapeIndex&) {
|
||||
if (LayoutUtil::IsSparseArray(subshape)) {
|
||||
return xla::Unimplemented(
|
||||
"CPU backend does not support HLO instruction %s with shape "
|
||||
"containing a sparse layout: %s",
|
||||
instruction->ToString(),
|
||||
ShapeUtil::HumanStringWithLayout(instruction->shape()));
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace xla
|
@ -1,40 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// This pass should run early in the HLO pipeline and checks for HLO constructs
|
||||
// which are not supported by the CPU backend and cannot be removed via HLO
|
||||
// transformations (eg, sparse layouts).
|
||||
class CpuHloSupportChecker : public HloModulePass {
|
||||
public:
|
||||
CpuHloSupportChecker() = default;
|
||||
~CpuHloSupportChecker() override = default;
|
||||
|
||||
absl::string_view name() const override { return "cpu_hlo_support_checker"; }
|
||||
|
||||
// Note: always returns false (no instructions are ever modified by this
|
||||
// pass).
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_
|
@ -1,76 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
class CpuHloSupportCheckerTest : public HloTestBase {
|
||||
protected:
|
||||
CpuHloSupportChecker& checker() { return checker_; }
|
||||
|
||||
private:
|
||||
CpuHloSupportChecker checker_;
|
||||
};
|
||||
|
||||
TEST_F(CpuHloSupportCheckerTest, Add) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, scalar_shape, "param0"));
|
||||
HloInstruction* param1 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, scalar_shape, "param1"));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
scalar_shape, HloOpcode::kAdd, param0, param1));
|
||||
auto module = CreateNewVerifiedModule();
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
TF_ASSERT_OK(checker().Run(module.get()).status());
|
||||
}
|
||||
|
||||
TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
const Shape sparse_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {10}, 2);
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, sparse_shape, "param0"));
|
||||
HloInstruction* param1 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, sparse_shape, "param1"));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
sparse_shape, HloOpcode::kAdd, param0, param1));
|
||||
// Since verifier is reporting sparse layouts as errors, we should
|
||||
// use a regular HloModule instead of VerifiedHloModule to avoid
|
||||
// verifier errors being triggered in the destructor.
|
||||
auto module = CreateNewUnverifiedModule();
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
Status status = checker().Run(module.get()).status();
|
||||
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
|
||||
EXPECT_THAT(status.error_message(),
|
||||
HasSubstr("CPU backend does not support"));
|
||||
EXPECT_THAT(status.error_message(),
|
||||
HasSubstr(ShapeUtil::HumanStringWithLayout(sparse_shape)));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
@ -116,29 +116,6 @@ non_tuple
|
||||
| rank2345
|
||||
;
|
||||
rank2345
|
||||
: shape sparse_or_nested_array
|
||||
: nested_array
|
||||
;
|
||||
sparse_or_nested_array
|
||||
: sparse_array
|
||||
| nested_array
|
||||
;
|
||||
sparse_array
|
||||
: '{' sparse_array1 '}'
|
||||
;
|
||||
sparse_array1
|
||||
: sparse_array_item
|
||||
| sparse_array1 ',' sparse_array_item
|
||||
;
|
||||
sparse_array_item
|
||||
: multi_index ':' scalar
|
||||
;
|
||||
multi_index
|
||||
: kInt
|
||||
| '[' multi_index1 ']'
|
||||
;
|
||||
multi_index1
|
||||
: kInt
|
||||
| multi_index1 ',' kInt
|
||||
;
|
||||
|
||||
```
|
||||
|
@ -1093,7 +1093,6 @@ cc_library(
|
||||
":gpu_copy_insertion",
|
||||
":gpu_executable",
|
||||
":gpu_hlo_schedule",
|
||||
":gpu_hlo_support_checker",
|
||||
":gpu_layout_assignment",
|
||||
":gpu_sanitize_constant_names",
|
||||
":gpu_scatter_expander",
|
||||
@ -1416,18 +1415,6 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gpu_hlo_support_checker",
|
||||
srcs = ["gpu_hlo_support_checker.cc"],
|
||||
hdrs = ["gpu_hlo_support_checker.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "stream_executor_util",
|
||||
srcs = ["stream_executor_util.cc"],
|
||||
@ -1455,20 +1442,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "gpu_hlo_support_checker_test",
|
||||
srcs = ["gpu_hlo_support_checker_test.cc"],
|
||||
deps = [
|
||||
":gpu_hlo_support_checker",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "buffer_comparator",
|
||||
srcs = ["buffer_comparator.cc"],
|
||||
|
@ -49,7 +49,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h"
|
||||
@ -135,7 +134,6 @@ Status GpuCompiler::OptimizeHloModule(
|
||||
pipeline.AddPass<GpuScatterExpander>();
|
||||
|
||||
pipeline.AddPass<DynamicIndexSplitter>();
|
||||
pipeline.AddPass<GpuHloSupportChecker>();
|
||||
|
||||
// TODO(b/64094172): make Call work on GPU instead of inlining.
|
||||
pipeline.AddPass<CallInliner>();
|
||||
|
@ -1,46 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
StatusOr<bool> GpuHloSupportChecker::Run(HloModule* module) {
|
||||
for (auto* computation : module->computations()) {
|
||||
for (const auto& instruction : computation->instructions()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape()));
|
||||
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
|
||||
instruction->shape(),
|
||||
[&instruction](const Shape& subshape, const ShapeIndex&) {
|
||||
if (LayoutUtil::IsSparseArray(subshape)) {
|
||||
return xla::Unimplemented(
|
||||
"GPU backend does not support HLO instruction %s with shape "
|
||||
"containing a sparse layout: %s",
|
||||
instruction->ToString(),
|
||||
ShapeUtil::HumanStringWithLayout(instruction->shape()));
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace xla
|
@ -1,40 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// This pass should run early in the HLO pipeline and checks for HLO constructs
|
||||
// which are not supported by the GPU backend and cannot be removed via HLO
|
||||
// transformations (eg, sparse layouts).
|
||||
class GpuHloSupportChecker : public HloModulePass {
|
||||
public:
|
||||
GpuHloSupportChecker() = default;
|
||||
~GpuHloSupportChecker() override = default;
|
||||
|
||||
absl::string_view name() const override { return "gpu_hlo_support_checker"; }
|
||||
|
||||
// Note: always returns false (no instructions are ever modified by this
|
||||
// pass).
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_
|
@ -1,76 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
class GpuHloSupportCheckerTest : public HloTestBase {
|
||||
protected:
|
||||
GpuHloSupportChecker& checker() { return checker_; }
|
||||
|
||||
private:
|
||||
GpuHloSupportChecker checker_;
|
||||
};
|
||||
|
||||
TEST_F(GpuHloSupportCheckerTest, Add) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, scalar_shape, "param0"));
|
||||
HloInstruction* param1 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, scalar_shape, "param1"));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
scalar_shape, HloOpcode::kAdd, param0, param1));
|
||||
auto module = CreateNewVerifiedModule();
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
TF_ASSERT_OK(checker().Run(module.get()).status());
|
||||
}
|
||||
|
||||
TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
const Shape sparse_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {10}, 2);
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, sparse_shape, "param0"));
|
||||
HloInstruction* param1 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, sparse_shape, "param1"));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
sparse_shape, HloOpcode::kAdd, param0, param1));
|
||||
// Since verifier is reporting sparse layouts as errors, we should
|
||||
// use a regular HloModule instead of VerifiedHloModule to avoid
|
||||
// verifier errors being triggered in the destructor.
|
||||
auto module = CreateNewUnverifiedModule();
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
Status status = checker().Run(module.get()).status();
|
||||
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
|
||||
EXPECT_THAT(status.error_message(),
|
||||
HasSubstr("GPU backend does not support"));
|
||||
EXPECT_THAT(status.error_message(),
|
||||
HasSubstr(ShapeUtil::HumanStringWithLayout(sparse_shape)));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
@ -280,7 +280,6 @@ TokKind HloLexer::LexIdentifier() {
|
||||
KEYWORD(ROOT);
|
||||
KEYWORD(maximal);
|
||||
KEYWORD(replicated);
|
||||
KEYWORD(sparse);
|
||||
|
||||
#undef KEYWORD
|
||||
|
||||
@ -496,8 +495,6 @@ string TokKindToString(TokKind kind) {
|
||||
return "kw_inf";
|
||||
case TokKind::kNegInf:
|
||||
return "kNegInf";
|
||||
case TokKind::kw_sparse:
|
||||
return "kw_sparse";
|
||||
case TokKind::kPrimitiveType:
|
||||
return "kPrimitiveType";
|
||||
case TokKind::kName:
|
||||
|
@ -63,7 +63,6 @@ enum class TokKind {
|
||||
kw_replicated,
|
||||
kw_nan,
|
||||
kw_inf,
|
||||
kw_sparse,
|
||||
|
||||
kNegInf, // -inf
|
||||
|
||||
|
@ -72,10 +72,6 @@ HloSchedule ScheduleFromInstructionOrder(HloModule* module) {
|
||||
return schedule;
|
||||
}
|
||||
|
||||
// Some functions accept either a linear index or a multi-dimensional index
|
||||
// (used for indexing into sparse literals).
|
||||
using LinearOrMultiIndex = absl::variant<int64, absl::Span<const int64>>;
|
||||
|
||||
// Parser for the HloModule::ToString() format text.
|
||||
class HloParserImpl : public HloParser {
|
||||
public:
|
||||
@ -137,24 +133,21 @@ class HloParserImpl : public HloParser {
|
||||
bool ParseTupleLiteral(Literal* literal, const Shape& shape);
|
||||
bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
|
||||
bool ParseDenseLiteral(Literal* literal, const Shape& shape);
|
||||
bool ParseSparseLiteral(Literal* literal, const Shape& shape);
|
||||
|
||||
// Sets the sub-value of literal at the given linear or sparse index to the
|
||||
// given value. If the literal is dense, it myst have the default layout.
|
||||
// Sets the sub-value of literal at the given linear index to the
|
||||
// given value. If the literal is dense, it must have the default layout.
|
||||
//
|
||||
// `loc` should be the source location of the value.
|
||||
bool SetValueInLiteral(LocTy loc, int64 value, LinearOrMultiIndex index,
|
||||
bool SetValueInLiteral(LocTy loc, int64 value, int64 index, Literal* literal);
|
||||
bool SetValueInLiteral(LocTy loc, double value, int64 index,
|
||||
Literal* literal);
|
||||
bool SetValueInLiteral(LocTy loc, double value, LinearOrMultiIndex index,
|
||||
bool SetValueInLiteral(LocTy loc, bool value, int64 index, Literal* literal);
|
||||
bool SetValueInLiteral(LocTy loc, std::complex<double> value, int64 index,
|
||||
Literal* literal);
|
||||
bool SetValueInLiteral(LocTy loc, bool value, LinearOrMultiIndex index,
|
||||
Literal* literal);
|
||||
bool SetValueInLiteral(LocTy loc, std::complex<double> value,
|
||||
LinearOrMultiIndex index, Literal* literal);
|
||||
// `loc` should be the source location of the value.
|
||||
template <typename LiteralNativeT, typename ParsedElemT>
|
||||
bool SetValueInLiteralHelper(LocTy loc, ParsedElemT value,
|
||||
LinearOrMultiIndex index, Literal* literal);
|
||||
bool SetValueInLiteralHelper(LocTy loc, ParsedElemT value, int64 index,
|
||||
Literal* literal);
|
||||
|
||||
// Checks whether the given value is within the range of LiteralNativeT.
|
||||
// `loc` should be the source location of the value.
|
||||
@ -2125,8 +2118,7 @@ bool HloParserImpl::ParseInstructionNames(
|
||||
"expects '}' at the end of instruction name list");
|
||||
}
|
||||
|
||||
bool HloParserImpl::SetValueInLiteral(LocTy loc, int64 value,
|
||||
LinearOrMultiIndex index,
|
||||
bool HloParserImpl::SetValueInLiteral(LocTy loc, int64 value, int64 index,
|
||||
Literal* literal) {
|
||||
const Shape& shape = literal->shape();
|
||||
switch (shape.element_type()) {
|
||||
@ -2160,8 +2152,7 @@ bool HloParserImpl::SetValueInLiteral(LocTy loc, int64 value,
|
||||
}
|
||||
}
|
||||
|
||||
bool HloParserImpl::SetValueInLiteral(LocTy loc, double value,
|
||||
LinearOrMultiIndex index,
|
||||
bool HloParserImpl::SetValueInLiteral(LocTy loc, double value, int64 index,
|
||||
Literal* literal) {
|
||||
const Shape& shape = literal->shape();
|
||||
switch (shape.element_type()) {
|
||||
@ -2180,8 +2171,7 @@ bool HloParserImpl::SetValueInLiteral(LocTy loc, double value,
|
||||
}
|
||||
}
|
||||
|
||||
bool HloParserImpl::SetValueInLiteral(LocTy loc, bool value,
|
||||
LinearOrMultiIndex index,
|
||||
bool HloParserImpl::SetValueInLiteral(LocTy loc, bool value, int64 index,
|
||||
Literal* literal) {
|
||||
const Shape& shape = literal->shape();
|
||||
switch (shape.element_type()) {
|
||||
@ -2194,8 +2184,7 @@ bool HloParserImpl::SetValueInLiteral(LocTy loc, bool value,
|
||||
}
|
||||
|
||||
bool HloParserImpl::SetValueInLiteral(LocTy loc, std::complex<double> value,
|
||||
LinearOrMultiIndex index,
|
||||
Literal* literal) {
|
||||
int64 index, Literal* literal) {
|
||||
const Shape& shape = literal->shape();
|
||||
switch (shape.element_type()) {
|
||||
case C64:
|
||||
@ -2221,54 +2210,21 @@ std::string StringifyValue(std::complex<double> val) {
|
||||
|
||||
template <typename LiteralNativeT, typename ParsedElemT>
|
||||
bool HloParserImpl::SetValueInLiteralHelper(LocTy loc, ParsedElemT value,
|
||||
LinearOrMultiIndex index,
|
||||
Literal* literal) {
|
||||
int64 index, Literal* literal) {
|
||||
if (!CheckParsedValueIsInRange<LiteralNativeT>(loc, value)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check that the index is in range and assign into the literal
|
||||
if (auto* linear_index = absl::get_if<int64>(&index)) {
|
||||
if (*linear_index >= ShapeUtil::ElementsIn(literal->shape())) {
|
||||
return Error(loc, StrCat("trys to set value ", StringifyValue(value),
|
||||
" to a literal in shape ",
|
||||
ShapeUtil::HumanString(literal->shape()),
|
||||
" at linear index ", *linear_index,
|
||||
", but the index is out of range"));
|
||||
}
|
||||
literal->data<LiteralNativeT>().at(*linear_index) =
|
||||
static_cast<LiteralNativeT>(value);
|
||||
} else {
|
||||
auto* multi_index = absl::get_if<absl::Span<const int64>>(&index);
|
||||
CHECK(multi_index != nullptr);
|
||||
|
||||
auto invalid_idx = [&](std::string msg) {
|
||||
return Error(loc, StrFormat("Invalid sparse index [%s]. %s",
|
||||
absl::StrJoin(*multi_index, ", "), msg));
|
||||
};
|
||||
|
||||
const auto& shape = literal->shape();
|
||||
if (shape.rank() != multi_index->size()) {
|
||||
return invalid_idx(
|
||||
StrFormat("Has rank %d, but constant has shape %s, which has rank %d",
|
||||
multi_index->size(), shape.ToString(), shape.rank()));
|
||||
}
|
||||
for (int64 i = 0; i < shape.rank(); ++i) {
|
||||
auto idx = (*multi_index)[i];
|
||||
if (idx < 0) {
|
||||
return invalid_idx(StrFormat(
|
||||
"Sub-index value at %d, namely %d, cannot be negative.", i, idx));
|
||||
}
|
||||
if (idx >= shape.dimensions(i)) {
|
||||
return invalid_idx(
|
||||
StrFormat("Sub-index at %d, namely %d, doesn't fit within shape "
|
||||
"dimension %d in %s",
|
||||
i, idx, shape.dimensions(i), shape.ToString()));
|
||||
}
|
||||
}
|
||||
literal->AppendSparseElement(*multi_index,
|
||||
static_cast<LiteralNativeT>(value));
|
||||
if (index >= ShapeUtil::ElementsIn(literal->shape())) {
|
||||
return Error(loc, StrCat("trys to set value ", StringifyValue(value),
|
||||
" to a literal in shape ",
|
||||
ShapeUtil::HumanString(literal->shape()),
|
||||
" at linear index ", index,
|
||||
", but the index is out of range"));
|
||||
}
|
||||
literal->data<LiteralNativeT>().at(index) =
|
||||
static_cast<LiteralNativeT>(value);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -2314,12 +2270,8 @@ bool HloParserImpl::ParseTupleLiteral(Literal* literal, const Shape& shape) {
|
||||
// non_tuple
|
||||
// ::= rank01
|
||||
// ::= rank2345
|
||||
// rank2345 ::= shape sparse_or_nested_array
|
||||
// rank2345 ::= shape nested_array
|
||||
bool HloParserImpl::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
|
||||
if (LayoutUtil::IsSparseArray(shape)) {
|
||||
return ParseSparseLiteral(literal, shape);
|
||||
}
|
||||
|
||||
CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ToString(true);
|
||||
return ParseDenseLiteral(literal, shape);
|
||||
}
|
||||
@ -2500,98 +2452,6 @@ bool HloParserImpl::ParseDenseLiteral(Literal* literal, const Shape& shape) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HloParserImpl::ParseSparseLiteral(Literal* literal, const Shape& shape) {
|
||||
*literal = Literal(shape);
|
||||
if (!ParseToken(TokKind::kLbrace,
|
||||
"expects '{' at the beginning of a sparse literal")) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (;;) {
|
||||
if (lexer_.GetKind() == TokKind::kRbrace) {
|
||||
lexer_.Lex();
|
||||
break;
|
||||
}
|
||||
|
||||
std::vector<int64> index;
|
||||
if (lexer_.GetKind() == TokKind::kInt) {
|
||||
int64 single_index = lexer_.GetInt64Val();
|
||||
lexer_.Lex();
|
||||
index.push_back(single_index);
|
||||
} else {
|
||||
if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma,
|
||||
&index)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (!ParseToken(TokKind::kColon,
|
||||
"expects ':' after after the sparse array index and before "
|
||||
"the sparse array value")) {
|
||||
return false;
|
||||
}
|
||||
|
||||
LocTy value_loc = lexer_.GetLoc();
|
||||
if (lexer_.GetKind() == TokKind::kw_true ||
|
||||
lexer_.GetKind() == TokKind::kw_false) {
|
||||
bool value = lexer_.GetKind() == TokKind::kw_true;
|
||||
if (!SetValueInLiteral(lexer_.GetLoc(), value, index, literal)) {
|
||||
return false;
|
||||
}
|
||||
lexer_.Lex();
|
||||
} else if (primitive_util::IsIntegralType(shape.element_type())) {
|
||||
int64 value;
|
||||
if (!ParseInt64(&value)) {
|
||||
return Error(value_loc,
|
||||
StrCat("expects integer for primitive type: ",
|
||||
PrimitiveType_Name(shape.element_type())));
|
||||
}
|
||||
if (!SetValueInLiteral(value_loc, value, index, literal)) {
|
||||
return false;
|
||||
}
|
||||
} else if (primitive_util::IsFloatingPointType(shape.element_type())) {
|
||||
double value;
|
||||
if (!ParseDouble(&value)) {
|
||||
return Error(value_loc,
|
||||
StrCat("expects floating point value for primitive type: ",
|
||||
PrimitiveType_Name(shape.element_type())));
|
||||
}
|
||||
if (!SetValueInLiteral(value_loc, value, index, literal)) {
|
||||
return false;
|
||||
}
|
||||
} else if (primitive_util::IsComplexType(shape.element_type())) {
|
||||
std::complex<double> value;
|
||||
if (!ParseComplex(&value)) {
|
||||
return Error(value_loc,
|
||||
StrCat("expects complex value for primitive type: ",
|
||||
PrimitiveType_Name(shape.element_type())));
|
||||
}
|
||||
if (!SetValueInLiteral(value_loc, value, index, literal)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
LOG(FATAL) << "Unexpected element type: "
|
||||
<< PrimitiveType_Name(shape.element_type());
|
||||
}
|
||||
|
||||
if (lexer_.GetKind() != TokKind::kRbrace &&
|
||||
!ParseToken(TokKind::kComma,
|
||||
"expects ',' separator between sparse array elements")) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (literal->sparse_element_count() + 1 ==
|
||||
LayoutUtil::MaxSparseElements(shape.layout())) {
|
||||
return Error(
|
||||
lexer_.GetLoc(),
|
||||
StrCat("number of sparse elements exceeds maximum for layout: ",
|
||||
ShapeUtil::HumanStringWithLayout(shape)));
|
||||
}
|
||||
}
|
||||
|
||||
literal->SortSparseElements();
|
||||
return true;
|
||||
}
|
||||
|
||||
// MaxFiniteValue is a type-traits helper used by
|
||||
// HloParserImpl::CheckParsedValueIsInRange.
|
||||
template <typename T>
|
||||
@ -3839,21 +3699,6 @@ bool HloParserImpl::ParseShape(Shape* result) {
|
||||
}
|
||||
LayoutUtil::SetToDefaultLayout(result);
|
||||
|
||||
if (lexer_.GetKind() == TokKind::kw_sparse) {
|
||||
lexer_.Lex();
|
||||
const std::string message =
|
||||
"expects a brace-bracketed integer for sparse layout";
|
||||
int64 max_sparse_elements;
|
||||
if (!ParseToken(TokKind::kLbrace, message) ||
|
||||
!ParseInt64(&max_sparse_elements) ||
|
||||
!ParseToken(TokKind::kRbrace, message)) {
|
||||
return false;
|
||||
}
|
||||
*result->mutable_layout() =
|
||||
LayoutUtil::MakeSparseLayout(max_sparse_elements);
|
||||
return true;
|
||||
}
|
||||
|
||||
// We need to lookahead to see if a following open brace is the start of a
|
||||
// layout. The specific problematic case is:
|
||||
//
|
||||
|
@ -841,50 +841,6 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] {
|
||||
)"
|
||||
},
|
||||
{
|
||||
"Sparse",
|
||||
R"(HloModule sparse_f32
|
||||
|
||||
ENTRY %sparse () -> f32[2,3,4] {
|
||||
ROOT %foo = f32[2,3,4]sparse{10} constant({[0, 1, 2]: 1, [1, 2, 2]: 2, [1, 2, 3]: 3})
|
||||
}
|
||||
|
||||
)",
|
||||
/*enable_verification=*/false
|
||||
},
|
||||
{
|
||||
"SparseC128",
|
||||
R"(HloModule sparse_c128
|
||||
|
||||
ENTRY %sparse () -> c128[2,3,4] {
|
||||
ROOT %foo = c128[2,3,4]sparse{10} constant({[0, 1, 2]: (1, 0), [1, 2, 2]: (2, 5), [1, 2, 3]: (3, 10)})
|
||||
}
|
||||
|
||||
)",
|
||||
/*enable_verification=*/false
|
||||
},
|
||||
{
|
||||
"SparseEmpty",
|
||||
R"(HloModule sparse_f32_empty
|
||||
|
||||
ENTRY %sparse_f32_empty () -> f32[2,3,4] {
|
||||
ROOT %foo = f32[2,3,4]sparse{10} constant({})
|
||||
}
|
||||
|
||||
)",
|
||||
/*enable_verification=*/false,
|
||||
},
|
||||
{
|
||||
"SparseR1",
|
||||
R"(HloModule sparse_f32_r1
|
||||
|
||||
ENTRY %sparse_f32_r1 () -> f32[9] {
|
||||
ROOT %foo = f32[9]sparse{10} constant({1: 2, 3: 4, 5: 6})
|
||||
}
|
||||
|
||||
)",
|
||||
/*enable_verification=*/false,
|
||||
},
|
||||
{
|
||||
"Gather",
|
||||
R"(HloModule StringifyGather
|
||||
|
||||
@ -1982,17 +1938,6 @@ TEST_F(HloParserTest, ConstantBf16Overflow) {
|
||||
"out of range");
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, ConstantF16OverflowInSparseArray) {
|
||||
const string original = R"(
|
||||
HloModule test_module
|
||||
ENTRY test {
|
||||
ROOT c = f16[5]sparse{10} constant({[0]: 0, [1]: -65520})
|
||||
})";
|
||||
ExpectHasSubstr(
|
||||
ParseAndReturnUnverifiedModule(original).status().error_message(),
|
||||
"is out of range for literal's primitive type F16");
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, ConstantUnsignedUnderflow) {
|
||||
const string original = R"(
|
||||
HloModule ConstantUnsignedUnderflow_module
|
||||
@ -2852,50 +2797,6 @@ ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] {
|
||||
" with the shape of the operand instruction f32[2,2]{1,0}.");
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, OutOfRangeSparseIndex) {
|
||||
const string original = R"(
|
||||
HloModule test_module
|
||||
ENTRY test {
|
||||
ROOT c = f16[5]sparse{10} constant({[100]: 0})
|
||||
})";
|
||||
ExpectHasSubstr(
|
||||
ParseAndReturnUnverifiedModule(original).status().error_message(),
|
||||
"Invalid sparse index");
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, NegativeSparseIndex) {
|
||||
const string original = R"(
|
||||
HloModule test_module
|
||||
ENTRY test {
|
||||
ROOT c = f16[5]sparse{10} constant({-1: 0})
|
||||
})";
|
||||
ExpectHasSubstr(
|
||||
ParseAndReturnUnverifiedModule(original).status().error_message(),
|
||||
"Invalid sparse index");
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, SparseIndexWithRankTooLarge) {
|
||||
const string original = R"(
|
||||
HloModule test_module
|
||||
ENTRY test {
|
||||
ROOT c = f16[5]sparse{10} constant({[0, 0]: 0})
|
||||
})";
|
||||
ExpectHasSubstr(
|
||||
ParseAndReturnUnverifiedModule(original).status().error_message(),
|
||||
"Invalid sparse index");
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, SparseIndexWithRankTooSmall) {
|
||||
const string original = R"(
|
||||
HloModule test_module
|
||||
ENTRY test {
|
||||
ROOT c = f16[5, 5]sparse{10} constant({[0]: 0})
|
||||
})";
|
||||
ExpectHasSubstr(
|
||||
ParseAndReturnUnverifiedModule(original).status().error_message(),
|
||||
"Invalid sparse index");
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, ParseShapeStringR2F32) {
|
||||
string shape_string = "f32[123,456]";
|
||||
TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
|
||||
@ -2994,15 +2895,6 @@ TEST_F(HloParserTest, ParseShapeStringWithTilingLayout) {
|
||||
"Dimensions size is 3, but minor to major size is 1.");
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, ParseShapeStringWithSparseLayout) {
|
||||
string shape_string = "f32[123,456]sparse{10}";
|
||||
TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
|
||||
Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10);
|
||||
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
|
||||
<< "expected: " << ShapeUtil::HumanString(expected)
|
||||
<< "actual: " << ShapeUtil::HumanString(actual);
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, ParseShapeStringWithMemorySpaceLayout) {
|
||||
// Tile, element size, and memory space.
|
||||
string shape_string = "pred[123,456]{1,0:T(2,128)E(1)S(3)}";
|
||||
@ -3047,10 +2939,8 @@ TEST_F(HloParserTest, ParseTokenType) {
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, ParseInvalidShapeString) {
|
||||
string shape_strings[] = {
|
||||
"f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}",
|
||||
"f32[123,456]dense{foo}", "f32[123,456]sparse{foo}",
|
||||
};
|
||||
string shape_strings[] = {"f32[123,456]foobar{0,1}", "f32[123,456]{foo}",
|
||||
"f32[123,456]dense{foo}"};
|
||||
for (const string& shape_string : shape_strings) {
|
||||
StatusOr<Shape> result = ParseShape(shape_string);
|
||||
ASSERT_FALSE(result.ok()) << "shape: " << shape_string;
|
||||
|
@ -33,17 +33,6 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
Status VerifyNotSparse(const Shape& shape) {
|
||||
return ShapeUtil::ForEachSubshapeWithStatus(
|
||||
shape, [](const Shape& subshape, const ShapeIndex&) -> Status {
|
||||
if (LayoutUtil::IsSparseArray(subshape)) {
|
||||
return InternalError("Sparse arrays are not yet fully supported: %s",
|
||||
ShapeUtil::HumanStringWithLayout(subshape));
|
||||
}
|
||||
return Status::OK();
|
||||
});
|
||||
}
|
||||
|
||||
bool IsCallerInstruction(HloInstruction* hlo) {
|
||||
switch (hlo->opcode()) {
|
||||
case HloOpcode::kCall:
|
||||
@ -93,8 +82,6 @@ Status ShapeVerifier::Preprocess(HloInstruction* hlo) {
|
||||
"Called computations specified for non-caller instruction %s",
|
||||
hlo->ToString());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(VerifyNotSparse(hlo->shape()));
|
||||
|
||||
absl::optional<int> arity = HloOpcodeArity(hlo->opcode());
|
||||
if (arity) {
|
||||
TF_RETURN_IF_ERROR(CheckOperandCount(hlo, *arity));
|
||||
@ -1109,8 +1096,6 @@ Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ShapeUtil::ValidateShapeWithOptionalLayout(result_layout.shape()));
|
||||
|
||||
TF_RETURN_IF_ERROR(VerifyNotSparse(result_layout.shape()));
|
||||
|
||||
if (!ShapeUtil::Compatible(computation->root_instruction()->shape(),
|
||||
result_layout.shape())) {
|
||||
return InternalError(
|
||||
@ -1131,7 +1116,6 @@ Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) {
|
||||
const HloInstruction* parameter = computation->parameter_instruction(i);
|
||||
TF_RETURN_IF_ERROR(
|
||||
ShapeUtil::ValidateShapeWithOptionalLayout(layout.parameter_shape(i)));
|
||||
TF_RETURN_IF_ERROR(VerifyNotSparse(layout.parameter_shape(i)));
|
||||
if (!ShapeUtil::Compatible(parameter->shape(), layout.parameter_shape(i))) {
|
||||
return InternalError(
|
||||
"Shape of the entry computation parameter %d is %s should be "
|
||||
|
@ -73,7 +73,7 @@ namespace xla {
|
||||
// - EqualTo
|
||||
// - CompatibleTo
|
||||
// - IsScalar/IsEffectiveScalar/IsArray/IsTuple
|
||||
// - IsDenseArray/IsSparseArray
|
||||
// - IsDenseArray
|
||||
// - WithLayout: layout shape's layout matches the given pattern (e.g.
|
||||
// Layout().WithDenseFormat())
|
||||
// - WithLayoutEqualTo: shape's layout equals the argument (i.e. another
|
||||
@ -87,7 +87,7 @@ namespace xla {
|
||||
//
|
||||
// Layout():
|
||||
// - EqualTo
|
||||
// - WithDenseFormat/WithSparseFormat
|
||||
// - WithDenseFormat
|
||||
//
|
||||
// Op(), Shape(), and Layout() may be passed an argument of type
|
||||
// HloInstruction**, Shape**, or Layout**, respectively, or const versions of
|
||||
@ -506,12 +506,6 @@ class LayoutPattern {
|
||||
return AppendImpl(LayoutPatternFormatImpl(DENSE));
|
||||
}
|
||||
|
||||
// Modifies the pattern to match only if the layout has a sparse format.
|
||||
constexpr auto WithSparseFormat() const
|
||||
-> decltype(this->AppendImpl(LayoutPatternFormatImpl(SPARSE))) {
|
||||
return AppendImpl(LayoutPatternFormatImpl(SPARSE));
|
||||
}
|
||||
|
||||
private:
|
||||
Impl impl_;
|
||||
LayoutType** matched_layout_;
|
||||
@ -1060,11 +1054,6 @@ class ShapePattern {
|
||||
return WithLayout(Layout().WithDenseFormat());
|
||||
}
|
||||
|
||||
constexpr auto IsSparseArray() const
|
||||
-> decltype(this->WithLayout(Layout().WithSparseFormat())) {
|
||||
return WithLayout(Layout().WithSparseFormat());
|
||||
}
|
||||
|
||||
// Modifies the pattern to match only if the shape has a subshape that matches
|
||||
// the given pattern.
|
||||
template <typename SubshapeType, typename SubshapeImpl>
|
||||
|
@ -56,9 +56,6 @@ TEST(PatternMatcherGmock, MatchShape) {
|
||||
TEST(PatternMatcherGmock, MatchLayout) {
|
||||
Layout l = LayoutUtil::MakeLayout({0, 1});
|
||||
EXPECT_THAT(l, GmockMatch(m::Layout()));
|
||||
EXPECT_THAT(&l, Not(GmockMatch(m::Layout().WithSparseFormat())));
|
||||
EXPECT_THAT(Describe<Layout>(GmockMatch(m::Layout().WithSparseFormat())),
|
||||
"a layout with format SPARSE");
|
||||
}
|
||||
|
||||
TEST(PatternMatchGmock, MatchInstruction) {
|
||||
|
@ -89,7 +89,6 @@ TEST_F(PatternMatcherTest, DenseArrayShape) {
|
||||
EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray()));
|
||||
EXPECT_EQ(matched_shape, &array_shape);
|
||||
EXPECT_TRUE(Match(&array_shape, match::Shape().IsDenseArray()));
|
||||
EXPECT_FALSE(Match(&array_shape, match::Shape().IsSparseArray()));
|
||||
EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar()));
|
||||
EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple()));
|
||||
EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32)));
|
||||
@ -97,38 +96,12 @@ TEST_F(PatternMatcherTest, DenseArrayShape) {
|
||||
EXPECT_FALSE(
|
||||
Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape())));
|
||||
Layout* matched_layout;
|
||||
EXPECT_FALSE(Match(&array_shape,
|
||||
match::Shape().WithLayout(
|
||||
match::Layout(&matched_layout).WithSparseFormat())));
|
||||
EXPECT_TRUE(Match(&array_shape,
|
||||
match::Shape().WithLayout(
|
||||
match::Layout(&matched_layout).WithDenseFormat())));
|
||||
EXPECT_EQ(matched_layout, &array_shape.layout());
|
||||
}
|
||||
|
||||
TEST_F(PatternMatcherTest, SparseArrayShape) {
|
||||
auto array_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {2, 3, 4}, 10);
|
||||
Shape* matched_shape;
|
||||
EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray()));
|
||||
EXPECT_EQ(matched_shape, &array_shape);
|
||||
EXPECT_FALSE(Match(&array_shape, match::Shape().IsDenseArray()));
|
||||
EXPECT_TRUE(Match(&array_shape, match::Shape().IsSparseArray()));
|
||||
EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar()));
|
||||
EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple()));
|
||||
EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32)));
|
||||
EXPECT_TRUE(Match(&array_shape, match::Shape().WithRank(3)));
|
||||
EXPECT_FALSE(
|
||||
Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape())));
|
||||
Layout* matched_layout;
|
||||
EXPECT_FALSE(Match(&array_shape,
|
||||
match::Shape().WithLayout(
|
||||
match::Layout(&matched_layout).WithDenseFormat())));
|
||||
EXPECT_TRUE(Match(&array_shape,
|
||||
match::Shape().WithLayout(
|
||||
match::Layout(&matched_layout).WithSparseFormat())));
|
||||
EXPECT_EQ(matched_layout, &array_shape.layout());
|
||||
}
|
||||
|
||||
TEST_F(PatternMatcherTest, TupleShape) {
|
||||
auto tuple_shape = ShapeUtil::MakeTupleShape({
|
||||
ShapeUtil::MakeShape(F32, {1, 2, 3}),
|
||||
@ -568,15 +541,6 @@ TEST_F(PatternMatcherTest, LayoutDescribeToAndExplain) {
|
||||
EXPECT_DESC_AND_EXPLANATION(layout2, m::Layout().EqualTo(&layout),
|
||||
"a layout equal to {1,2}",
|
||||
"Layout {2,2} is not equal to expected {1,2}");
|
||||
EXPECT_DESC_AND_EXPLANATION(layout2, m::Layout().WithSparseFormat(),
|
||||
"a layout with format SPARSE",
|
||||
"Layout has format DENSE but expected SPARSE");
|
||||
EXPECT_DESC_AND_EXPLANATION(layout,
|
||||
m::Layout().EqualTo(&layout).WithSparseFormat(),
|
||||
"a layout:\n"
|
||||
" * equal to {1,2} AND\n"
|
||||
" * with format SPARSE",
|
||||
"Layout has format DENSE but expected SPARSE");
|
||||
}
|
||||
|
||||
TEST_F(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) {
|
||||
@ -665,11 +629,6 @@ TEST_F(PatternMatcherTest, ShapeDescribeToAndExplain) {
|
||||
"a shape with\n a layout equal to {0,1}",
|
||||
"Layout {1,0} is not equal to expected {0,1}\n"
|
||||
"in f32[1,2]{1,0}");
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
shape, m::Shape().WithLayout(m::Layout().WithSparseFormat()),
|
||||
"a shape with\n a layout with format SPARSE",
|
||||
"Layout has format DENSE but expected SPARSE\n"
|
||||
"in f32[1,2]{0,1}");
|
||||
EXPECT_DESC_AND_EXPLANATION(shape,
|
||||
m::Shape().WithSubshapeEqualTo({10}, &shape),
|
||||
"a shape with subshape at index {10} which is\n"
|
||||
|
@ -229,16 +229,6 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
||||
return MakeShapeWithLayout(element_type, dimensions, layout);
|
||||
}
|
||||
|
||||
/* static */ Shape ShapeUtil::MakeShapeWithSparseLayout(
|
||||
PrimitiveType element_type, absl::Span<const int64> dimensions,
|
||||
int64 max_sparse_elements) {
|
||||
CHECK(IsArrayPrimitiveType(element_type));
|
||||
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
|
||||
*shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
|
||||
TF_DCHECK_OK(ShapeUtil::ValidateShape(shape));
|
||||
return shape;
|
||||
}
|
||||
|
||||
/* static */ Shape
|
||||
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
const Shape& shape) {
|
||||
@ -637,9 +627,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
return ByteSizeOfTupleIndexTable(shape, pointer_size);
|
||||
} else if (shape.IsArray()) {
|
||||
int64 byte_size = ByteSizeOfElements(shape);
|
||||
if (LayoutUtil::IsSparseArray(shape)) {
|
||||
byte_size += ByteSizeOfSparseIndices(shape);
|
||||
}
|
||||
return byte_size;
|
||||
} else if (shape.element_type() == TOKEN) {
|
||||
return 0;
|
||||
@ -664,23 +651,12 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
CHECK(shape.IsArray());
|
||||
int64 allocated_element_count;
|
||||
|
||||
if (LayoutUtil::IsSparseArray(shape)) {
|
||||
allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout());
|
||||
} else {
|
||||
CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString();
|
||||
allocated_element_count = ElementsIn(shape);
|
||||
}
|
||||
CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString();
|
||||
allocated_element_count = ElementsIn(shape);
|
||||
return allocated_element_count *
|
||||
ByteSizeOfPrimitiveType(shape.element_type());
|
||||
}
|
||||
|
||||
/* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) {
|
||||
TF_DCHECK_OK(ValidateShape(shape));
|
||||
CHECK(LayoutUtil::IsSparseArray(shape));
|
||||
return LayoutUtil::MaxSparseElements(shape.layout()) * shape.rank() *
|
||||
sizeof(int64);
|
||||
}
|
||||
|
||||
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
|
||||
const Shape& shape) {
|
||||
if (shape.element_type() == PRIMITIVE_TYPE_INVALID ||
|
||||
@ -721,9 +697,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (LayoutUtil::IsSparseArray(shape) && shape.rank() == 0) {
|
||||
return InvalidArgument("sparse arrays must have rank > 0");
|
||||
}
|
||||
for (int64 i = 0; i < shape.rank(); ++i) {
|
||||
int64 dimension = shape.dimensions(i);
|
||||
if (dimension < 0) {
|
||||
@ -744,43 +717,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// We can only reason about some aspects of array's shape if it has a valid
|
||||
// layout, these aspects will be ignored otherwise.
|
||||
bool shape_has_valid_layout = LayoutUtil::HasLayout(shape) &&
|
||||
LayoutUtil::ValidateLayoutInShape(shape).ok();
|
||||
|
||||
int64 shape_size = [&]() {
|
||||
if (shape_has_valid_layout && LayoutUtil::IsSparseArray(shape)) {
|
||||
int64 max_sparse_elements = LayoutUtil::MaxSparseElements(shape.layout());
|
||||
if (max_sparse_elements < 0) {
|
||||
return max_sparse_elements;
|
||||
}
|
||||
int64 sparse_elements_size = MultiplyWithoutOverflow(
|
||||
max_sparse_elements, ByteSizeOfPrimitiveType(shape.element_type()));
|
||||
if (sparse_elements_size < 0) {
|
||||
return sparse_elements_size;
|
||||
}
|
||||
int64 sparse_indices_size =
|
||||
MultiplyWithoutOverflow(max_sparse_elements, shape.rank());
|
||||
if (sparse_indices_size < 0) {
|
||||
return sparse_indices_size;
|
||||
}
|
||||
sparse_indices_size =
|
||||
MultiplyWithoutOverflow(sparse_indices_size, sizeof(int64));
|
||||
if (sparse_indices_size < 0) {
|
||||
return sparse_indices_size;
|
||||
}
|
||||
// At this point, both sparse_indices_size and sparse_elements_size are
|
||||
// non-negative, so we can easily check if adding them wraps.
|
||||
if (static_cast<uint64>(sparse_elements_size) +
|
||||
static_cast<uint64>(sparse_indices_size) >
|
||||
INT64_MAX) {
|
||||
return static_cast<int64>(-1);
|
||||
}
|
||||
}
|
||||
|
||||
// This is intentionally unconditional: even if the shape is sparse, we want
|
||||
// to verify the densified version has a reasonable size.
|
||||
int64 dense_shape_size = 1;
|
||||
if (shape.dimensions().empty()) {
|
||||
return dense_shape_size;
|
||||
|
@ -192,10 +192,7 @@ class ShapeUtil {
|
||||
};
|
||||
|
||||
// Returns the number of elements are contained within the provided shape;
|
||||
// e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes
|
||||
// may not actually be able to store this number of elements. See
|
||||
// LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of
|
||||
// elements that can be stored in a sparse shape.
|
||||
// e.g. for rank 0 (scalars) the result is always 1.
|
||||
// Precondition: shape.IsArray()
|
||||
static int64 ElementsIn(const Shape& shape);
|
||||
|
||||
@ -228,20 +225,12 @@ class ShapeUtil {
|
||||
int64 pointer_size);
|
||||
|
||||
// Returns the number of bytes required for the elements in an allocation of
|
||||
// `shape`, which must be an array shape. The return value does not include
|
||||
// the bytes needed to store sparse indices. Dense shapes use a separate
|
||||
// `shape`, which must be an array shape. Shapes use a separate
|
||||
// memory location for each element, and so for these shapes,
|
||||
// `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. For dense shapes, this
|
||||
// size also includes padding if present in the layout. For sparse shapes,
|
||||
// `ByteSizeOf(shape) == ByteSizeOfElements(shape) +
|
||||
// ByteSizeOfSparseindices(shape)`.
|
||||
// `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. This
|
||||
// size also includes padding if present in the layout.
|
||||
static int64 ByteSizeOfElements(const Shape& shape);
|
||||
|
||||
// Returns the number of bytes required for the sparse indices in an
|
||||
// allocation of shape. The shape must be an array shape. The return value
|
||||
// does not include the bytes needed to store sparse indices.
|
||||
static int64 ByteSizeOfSparseIndices(const Shape& shape);
|
||||
|
||||
// Returns a human-readable string that represents the given shape, with or
|
||||
// without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]".
|
||||
static string HumanString(const Shape& shape);
|
||||
@ -427,9 +416,6 @@ class ShapeUtil {
|
||||
int64 element_size_in_bits = 0,
|
||||
int64 memory_space = 0);
|
||||
|
||||
static Shape MakeShapeWithSparseLayout(PrimitiveType element_type,
|
||||
absl::Span<const int64> dimensions,
|
||||
int64 max_sparse_elements);
|
||||
// Returns the same shape except with all dimensions set to be static.
|
||||
static Shape MakeShapeWithStaticDimensions(const Shape& shape);
|
||||
|
||||
|
@ -1,109 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/sparse_index_array.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/index_util.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
SparseIndexArray::SparseIndexArray() : rank_(0), max_indices_(0) {}
|
||||
|
||||
SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank,
|
||||
std::vector<int64> indices)
|
||||
: indices_(std::move(indices)), rank_(rank), max_indices_(max_indices) {
|
||||
CHECK_GT(rank_, 0);
|
||||
CHECK_EQ(indices_.size() % rank_, 0)
|
||||
<< "indices_.size(): " << indices_.size() << ", rank_: " << rank_;
|
||||
CHECK_LE(index_count(), max_indices_);
|
||||
}
|
||||
|
||||
SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank,
|
||||
absl::Span<const int64> indices)
|
||||
: SparseIndexArray(max_indices, rank,
|
||||
std::vector<int64>(indices.begin(), indices.end())) {}
|
||||
|
||||
SparseIndexArray::SparseIndexArray(int64 max_indices,
|
||||
const Array2D<int64>& indices)
|
||||
: SparseIndexArray(max_indices, indices.n2(),
|
||||
std::vector<int64>(indices.begin(), indices.end())) {}
|
||||
|
||||
int64 SparseIndexArray::index_count() const {
|
||||
CHECK_GT(rank_, 0);
|
||||
CHECK_EQ(indices_.size() % rank_, 0);
|
||||
return indices_.size() / rank_;
|
||||
}
|
||||
|
||||
absl::Span<const int64> SparseIndexArray::At(
|
||||
int64 sparse_element_number) const {
|
||||
CHECK_GT(rank_, 0);
|
||||
CHECK_GE(sparse_element_number, 0);
|
||||
CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size());
|
||||
return absl::Span<const int64>(
|
||||
indices_.data() + rank_ * sparse_element_number, rank_);
|
||||
}
|
||||
|
||||
absl::Span<int64> SparseIndexArray::At(int64 sparse_element_number) {
|
||||
CHECK_GT(rank_, 0);
|
||||
CHECK_GE(sparse_element_number, 0);
|
||||
CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size());
|
||||
return absl::Span<int64>(indices_.data() + rank_ * sparse_element_number,
|
||||
rank_);
|
||||
}
|
||||
|
||||
void SparseIndexArray::Append(absl::Span<const int64> index) {
|
||||
CHECK_GT(rank_, 0);
|
||||
CHECK_EQ(index.size(), rank_);
|
||||
indices_.insert(indices_.end(), index.begin(), index.end());
|
||||
}
|
||||
|
||||
void SparseIndexArray::Clear() { indices_.clear(); }
|
||||
|
||||
void SparseIndexArray::Resize(int64 num_indices) {
|
||||
CHECK_GT(rank_, 0);
|
||||
indices_.resize(rank_ * num_indices);
|
||||
}
|
||||
|
||||
bool SparseIndexArray::Validate(const Shape& shape) const {
|
||||
if (rank_ == 0 || rank_ != shape.rank()) {
|
||||
return false;
|
||||
}
|
||||
int64 num_indices = index_count();
|
||||
if (num_indices > LayoutUtil::MaxSparseElements(shape.layout())) {
|
||||
return false;
|
||||
}
|
||||
if (num_indices < 2) {
|
||||
return true;
|
||||
}
|
||||
absl::Span<const int64> last = At(0);
|
||||
if (!IndexUtil::IndexInBounds(shape, last)) {
|
||||
return false;
|
||||
}
|
||||
for (int64 n = 1; n < num_indices; ++n) {
|
||||
absl::Span<const int64> next = At(n);
|
||||
if (!IndexUtil::IndexInBounds(shape, next)) {
|
||||
return false;
|
||||
}
|
||||
if (IndexUtil::CompareIndices(last, next) >= 0) {
|
||||
return false;
|
||||
}
|
||||
last = next;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace xla
|
@ -1,176 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Utility class for managing sparse array indices.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/array2d.h"
|
||||
#include "tensorflow/compiler/xla/index_util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Encapsulates the array of indices for a sparse array. A SparseIndexArray
|
||||
// contain indices for up to `max_indices` elements of a sparse array. Each
|
||||
// sparse index is an array of `rank` int64 value that gives the location of a
|
||||
// value within a sparse array. Note that the dimensions of the array are not
|
||||
// checked (except for the rank). To avoid confusion, we refer to the position
|
||||
// of an index within a SparseIndexArray as a sparse index number.
|
||||
class SparseIndexArray {
|
||||
public:
|
||||
SparseIndexArray();
|
||||
SparseIndexArray(const SparseIndexArray&) = default;
|
||||
SparseIndexArray(SparseIndexArray&&) = default;
|
||||
SparseIndexArray& operator=(const SparseIndexArray&) = default;
|
||||
SparseIndexArray& operator=(SparseIndexArray&&) = default;
|
||||
|
||||
// Constructs a SparseIndexArray that can hold up to `max_indices` sparse
|
||||
// indices, with an initial contents obtained from the given array. The rank
|
||||
// is taken from the minor dimension of the array. The major dimension of the
|
||||
// array must not exceed `max_indices`.
|
||||
SparseIndexArray(int64 max_indices, const Array2D<int64>& indices);
|
||||
|
||||
// Like above, but the array is flattened. For example, the following are
|
||||
// equivalent:
|
||||
//
|
||||
// SparseIndexArray(10, 3,
|
||||
// Array2D{
|
||||
// {0, 1, 2},
|
||||
// {3, 4, 5},
|
||||
// {6, 7, 8},
|
||||
// {9, 10, 11},
|
||||
// })
|
||||
//
|
||||
// SparseIndexArray(10, 3,
|
||||
// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11})
|
||||
//
|
||||
SparseIndexArray(int64 max_indices, int64 rank,
|
||||
std::vector<int64> indices = {});
|
||||
SparseIndexArray(int64 max_indices, int64 rank,
|
||||
absl::Span<const int64> indices);
|
||||
|
||||
// Returns the number of elements represented by the indices stored in the
|
||||
// array.
|
||||
int64 index_count() const;
|
||||
|
||||
// Returns a slice that refers to the given sparse index number. The argument
|
||||
// must be in the range [0, element_count()).
|
||||
absl::Span<const int64> At(int64 sparse_element_number) const;
|
||||
absl::Span<int64> At(int64 sparse_element_number);
|
||||
|
||||
// Adds the given index at the end of the array. The new size of the
|
||||
// SparseIndexArray must not exceed `max_indices`.
|
||||
void Append(absl::Span<const int64> index);
|
||||
|
||||
// Removes all indices from the array.
|
||||
void Clear();
|
||||
|
||||
// Resizes the array to contain the given number of sparse indices. The new
|
||||
// size must be smaller than `max_indices`. If the new size is larger than
|
||||
// the old size, the value of the new indices is not specified.
|
||||
void Resize(int64 num_indices);
|
||||
|
||||
// Returns true iff all indices are unique and occur in sorted order, and are
|
||||
// valid for the given shape.
|
||||
bool Validate(const Shape& shape) const;
|
||||
|
||||
int64 rank() const { return rank_; }
|
||||
int64 max_indices() const { return max_indices_; }
|
||||
|
||||
// Returns a pointer to the int64 array that holds the sparse indices.
|
||||
absl::Span<int64> mutable_data() { return absl::MakeSpan(indices_); }
|
||||
absl::Span<const int64> data() const { return indices_; }
|
||||
|
||||
// Sorts this sparse index array along with the set of corresponding values.
|
||||
// The indices and values are sorted in the lexicographic order of the
|
||||
// indices, from smallest to largest.
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// std::vector<float> v{10.0, 11.0, 12.0};
|
||||
// SparseIndexArray a(10, 3,
|
||||
// {{3, 4, 5},
|
||||
// {1, 2, 3},
|
||||
// {2, 3, 4}});
|
||||
// a.SortWithValues(&v);
|
||||
// // Prints "11.0, 12.0, 10.0":
|
||||
// std::cout << v[0] << ", " << v[1] << ", " << v[2] << std::endl;
|
||||
//
|
||||
template <typename NativeT>
|
||||
void SortWithValues(absl::Span<NativeT> values);
|
||||
|
||||
private:
|
||||
std::vector<int64> indices_;
|
||||
int64 rank_;
|
||||
int64 max_indices_;
|
||||
};
|
||||
|
||||
template <typename NativeT>
|
||||
void SparseIndexArray::SortWithValues(absl::Span<NativeT> values) {
|
||||
int64 num_elements = index_count();
|
||||
CHECK_EQ(values.size(), num_elements);
|
||||
std::vector<int64> sort_order;
|
||||
sort_order.reserve(num_elements);
|
||||
for (int64 i = 0; i < num_elements; ++i) {
|
||||
sort_order.push_back(i);
|
||||
}
|
||||
auto sort_order_less = [this](int64 lhs, int64 rhs) {
|
||||
return IndexUtil::CompareIndices(At(lhs), At(rhs)) < 0;
|
||||
};
|
||||
absl::c_sort(sort_order, sort_order_less);
|
||||
|
||||
// Reorder the array elements according to sort_order. Work through the array
|
||||
// and follow cycles so we can do the reorder in-place.
|
||||
absl::InlinedVector<int64, 8> saved_index(rank());
|
||||
for (int64 i = 0; i < num_elements; ++i) {
|
||||
// sort_order[i] == -1 indicates the element has already been copied.
|
||||
if (sort_order[i] < 0) {
|
||||
continue;
|
||||
} else if (i == sort_order[i]) {
|
||||
// The element is already in sorted order.
|
||||
sort_order[i] = -1;
|
||||
continue;
|
||||
}
|
||||
|
||||
std::copy_n(At(i).begin(), rank(), saved_index.begin());
|
||||
NativeT saved_value = values[i];
|
||||
int64 j = i;
|
||||
for (;;) {
|
||||
if (sort_order[j] == i) {
|
||||
std::copy_n(saved_index.begin(), rank(), At(j).begin());
|
||||
values[j] = saved_value;
|
||||
sort_order[j] = -1;
|
||||
break;
|
||||
}
|
||||
|
||||
std::copy_n(At(sort_order[j]).begin(), rank(), At(j).begin());
|
||||
values[j] = values[sort_order[j]];
|
||||
|
||||
int64 k = sort_order[j];
|
||||
sort_order[j] = -1;
|
||||
j = k;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
|
@ -1,43 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/sparse_index_array.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
TEST(SparseIndexArrayTest, Sort) {
|
||||
SparseIndexArray a(10, 3);
|
||||
a.Append({2, 3, 4});
|
||||
a.Append({3, 4, 5});
|
||||
a.Append({1, 2, 3});
|
||||
a.Append({5, 6, 7});
|
||||
a.Append({4, 5, 6});
|
||||
a.Append({6, 7, 8});
|
||||
std::vector<double> values = {
|
||||
12.0, 13.0, 11.0, 15.0, 14.0, 16.0,
|
||||
};
|
||||
a.SortWithValues<double>(absl::MakeSpan(values));
|
||||
ASSERT_EQ(a.data(), std::vector<int64>({1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5,
|
||||
6, 7, 6, 7, 8}));
|
||||
ASSERT_EQ(values, std::vector<double>({11.0, 12.0, 13.0, 14.0, 15.0, 16.0}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
@ -115,9 +115,8 @@ enum Format {
|
||||
INVALID_FORMAT = 0;
|
||||
// The default layout, with exactly one storage location per element.
|
||||
DENSE = 1;
|
||||
// A sparsely encoded layout, providing only the index/value pairs of non-zero
|
||||
// elements.
|
||||
SPARSE = 2;
|
||||
reserved 2;
|
||||
reserved "SPARSE";
|
||||
}
|
||||
|
||||
// Describes a tile used in tiling-based layout. Refer to
|
||||
@ -156,10 +155,8 @@ message LayoutProto {
|
||||
reserved 3;
|
||||
reserved "padding_value";
|
||||
|
||||
// The maximum number of elements that can be stored for SPARSE formats. This
|
||||
// can be used to determine the maximum size in bytes of arrays stored in
|
||||
// memory. This field must be unset unless the format is SPARSE.
|
||||
int64 max_sparse_elements = 5;
|
||||
reserved 5;
|
||||
reserved "max_sparse_elements";
|
||||
|
||||
// A sequence of tiles, starting from the tile that's applied first to the
|
||||
// Shape.
|
||||
|
Loading…
Reference in New Issue
Block a user