[XLA] Remove unsupported sparse layout

Sparse layouts are not supported on any of the backends.

For backwards compatibility the fields stay in the protobuf, but parsing them
is a no-op.

PiperOrigin-RevId: 287924498
Change-Id: I8b1c1ec52e3a423015837bc10deee832921ba66c
Author: George Karpenkov
Date: 2020-01-02 18:01:40 -08:00 (committed by TensorFlower Gardener)
Parent commit: 1a416ed6a5
Commit: 2c431b6169
36 changed files with 45 additions and 1849 deletions

View File

@ -417,7 +417,6 @@ cc_library(
":array3d",
":array4d",
":shape_util",
":sparse_index_array",
":status_macros",
":types",
":util",
@ -463,7 +462,6 @@ cc_library(
":array4d",
":literal",
":shape_util",
":sparse_index_array",
":status_macros",
":types",
":util",
@ -840,29 +838,6 @@ tf_cc_test(
],
)
# Library backing sparse literals: stores the multi-indices of the
# explicitly-represented elements of a sparse array (removed by this change
# along with the rest of sparse layout support).
cc_library(
name = "sparse_index_array",
srcs = ["sparse_index_array.cc"],
hdrs = ["sparse_index_array.h"],
deps = [
":array2d",
":shape_util",
":xla_data_proto_cc",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/types:span",
],
)
# Unit tests for :sparse_index_array.
tf_cc_test(
name = "sparse_index_array_test",
srcs = ["sparse_index_array_test.cc"],
deps = [
":sparse_index_array",
":test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "parse_flags_from_env",
srcs = ["parse_flags_from_env.cc"],

View File

@ -52,7 +52,6 @@ string Tile::ToString() const {
for (const int64 dimension : proto.minor_to_major()) {
layout.add_minor_to_major(dimension);
}
layout.set_max_sparse_elements(proto.max_sparse_elements());
for (const TileProto& tile_proto : proto.tiles()) {
*layout.add_tiles() = Tile::CreateFromProto(tile_proto);
}
@ -68,7 +67,6 @@ LayoutProto Layout::ToProto() const {
for (const int64 dimension : minor_to_major()) {
proto.add_minor_to_major(dimension);
}
proto.set_max_sparse_elements(max_sparse_elements_);
for (const Tile& tile : tiles()) {
*proto.add_tiles() = tile.ToProto();
}
@ -78,10 +76,7 @@ LayoutProto Layout::ToProto() const {
}
string Layout::ToString() const {
if (format() == SPARSE) {
CHECK_EQ(tiles_size(), 0) << "Sparse layout should not be tiled.";
return absl::StrCat("sparse{", max_sparse_elements(), "}");
} else if (format() == DENSE) {
if (format() == DENSE) {
string colon_string = tiles().empty() ? "" : "T";
for (Tile tile : tiles()) {
absl::StrAppend(&colon_string, tile.ToString());
@ -107,10 +102,6 @@ bool Layout::Equal::operator()(const Layout& lhs, const Layout& rhs) {
if (lhs.format() == DENSE && lhs.minor_to_major() != rhs.minor_to_major()) {
return false;
}
if (lhs.format() == SPARSE &&
lhs.max_sparse_elements() != rhs.max_sparse_elements()) {
return false;
}
if (!ignore_tiles_ && lhs.tiles() != rhs.tiles()) {
return false;
}

View File

@ -203,12 +203,6 @@ class Layout {
absl::Span<const Tile> tiles() const { return tiles_; }
absl::InlinedVector<Tile, 2>* mutable_tiles() { return &tiles_; }
// Methods for accessing the int64 fields.
int64 max_sparse_elements() const { return max_sparse_elements_; }
Layout& set_max_sparse_elements(int64 value) {
max_sparse_elements_ = value;
return *this;
}
int64 element_size_in_bits() const { return element_size_in_bits_; }
Layout& set_element_size_in_bits(int64 value) {
element_size_in_bits_ = value;
@ -233,8 +227,7 @@ class Layout {
template <typename H>
friend H AbslHashValue(H h, const Layout& l) {
return H::combine(std::move(h), l.format_, l.minor_to_major_,
l.max_sparse_elements_, l.tiles_,
return H::combine(std::move(h), l.format_, l.minor_to_major_, l.tiles_,
l.element_size_in_bits_);
}
@ -255,11 +248,6 @@ class Layout {
// And the major dim is [8,100,100,3][1], which is size 100.
absl::InlinedVector<int64, 6> minor_to_major_;
// The maximum number of elements that can be stored for SPARSE formats. This
// can be used to determine the maximum size in bytes of arrays stored in
// memory. This field must be zero unless the format is SPARSE.
int64 max_sparse_elements_ = 0;
// The tiles used in tiling-based layout.
absl::InlinedVector<Tile, 2> tiles_;

View File

@ -34,8 +34,6 @@ class LayoutTest : public ::testing::Test {};
TEST_F(LayoutTest, ToString) {
EXPECT_EQ(Layout().ToString(), "invalid{}");
EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}");
EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(123).ToString(),
"sparse{123}");
EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}");
EXPECT_EQ(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})}).ToString(),
"{3,2,1,0:T(42,123)(4,5)}");
@ -65,11 +63,6 @@ TEST_F(LayoutTest, StreamOut) {
}
}
// Verifies that MaxSparseElements reads back the cap a layout was created
// with via MakeSparseLayout.
TEST_F(LayoutTest, SparseLayoutMaxElements) {
EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)),
101);
}
TEST_F(LayoutTest, Equality) {
EXPECT_EQ(Layout(), Layout());
const std::vector<int64> empty_dims;
@ -90,12 +83,6 @@ TEST_F(LayoutTest, Equality) {
Layout({0, 1, 2}).set_memory_space(3));
EXPECT_NE(Layout({0, 1, 2}).set_memory_space(1),
Layout({0, 1, 2}).set_memory_space(3));
EXPECT_EQ(Layout().set_format(SPARSE), Layout().set_format(SPARSE));
EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(42),
Layout().set_format(SPARSE).set_max_sparse_elements(42));
EXPECT_NE(Layout().set_format(SPARSE).set_max_sparse_elements(42),
Layout().set_format(SPARSE).set_max_sparse_elements(24));
EXPECT_FALSE(
Layout::Equal()(Layout({0, 1, 2}, {Tile({42, 44})}), Layout({0, 1, 2})));
EXPECT_TRUE(Layout::Equal().IgnoreTiles()(Layout({0, 1, 2}, {Tile({42, 44})}),
@ -117,8 +104,6 @@ TEST_F(LayoutTest, LayoutToFromProto) {
expect_unchanged(Layout());
expect_unchanged(Layout({1, 3, 2, 0}));
expect_unchanged(Layout().set_format(SPARSE));
expect_unchanged(Layout().set_format(SPARSE).set_max_sparse_elements(123));
expect_unchanged(Layout({0, 1}).set_element_size_in_bits(42));
expect_unchanged(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})}));
}

View File

@ -94,13 +94,6 @@ void SetDefaultLayoutToContainer(T* minor_to_major) {
return layout;
}
/* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) {
  // A sparse layout is just the SPARSE format tag plus the cap on how many
  // elements may be stored; the setters chain, so build it in one expression.
  return Layout().set_format(SPARSE).set_max_sparse_elements(
      max_sparse_elements);
}
namespace {
// Internal helper that creates a default layout for an array of the given rank.
@ -293,19 +286,6 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
layout.minor_to_major().end(), std::greater<int64>());
}
/* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) {
  // Sparse-ness is only meaningful for array shapes that actually carry a
  // layout; tuples and layout-less shapes are never sparse.
  if (!shape.IsArray() || !shape.has_layout()) {
    return false;
  }
  return IsSparse(shape.layout());
}
/* static */ bool LayoutUtil::IsSparse(const Layout& layout) {
  // A layout is sparse exactly when its format field is SPARSE.
  return SPARSE == layout.format();
}
/* static */ int64 LayoutUtil::MaxSparseElements(const Layout& layout) {
CHECK(IsSparse(layout));
return layout.max_sparse_elements();
}
/* static */ bool LayoutUtil::HasLayout(const Shape& shape) {
if (shape.IsTuple()) {
// Tuple shape: all subshapes must have a layout.
@ -461,8 +441,6 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
for (int64 minor_to_major : layout.minor_to_major()) {
hash_value = Hash64Combine(hash_value, hash<int64>()(minor_to_major));
}
hash_value = Hash64Combine(hash_value, layout.max_sparse_elements());
for (Tile tile : layout.tiles()) {
for (int64 tile_dim : tile.dimensions()) {
hash_value = Hash64Combine(hash_value, hash<int64>()(tile_dim));

View File

@ -49,10 +49,6 @@ class LayoutUtil {
// dimensions.
static Layout MakeDescendingLayout(int64 rank);
// Creates a sparse layout with the given maximum number of elements. (This is
// a convenience function for protobuf construction.)
static Layout MakeSparseLayout(int64 max_sparse_elements);
// Returns default layout for the given shape.
static Layout GetDefaultLayoutForShape(const Shape& shape);
@ -109,17 +105,6 @@ class LayoutUtil {
// more minor, and so on until dimension N-1 which is the minor.
static bool IsMonotonicWithDim0Major(const Layout& layout);
// Returns whether the given Shape is an array (i.e. not a tuple) and has a
// sparse format layout.
static bool IsSparseArray(const Shape& shape);
// Returns whether the given Layout has a sparse format.
static bool IsSparse(const Layout& layout);
// Returns the maximum number of elements that can be stored in a sparse
// layout.
static int64 MaxSparseElements(const Layout& layout);
// Returns whether the given shape has a layout. For tuple shapes, true is
// returned only if all elements have layouts.
static bool HasLayout(const Shape& shape);

View File

@ -33,14 +33,6 @@ class LayoutUtilTest : public ::testing::Test {
*shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
return shape;
}
// Builds an array shape of `element_type`/`dimensions` whose layout is
// sparse with the given maximum element count.
Shape MakeShapeWithSparseLayout(PrimitiveType element_type,
absl::Span<const int64> dimensions,
int64 max_sparse_elements) {
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
*shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
return shape;
}
};
TEST_F(LayoutUtilTest, TupleLayoutComparison) {
@ -92,29 +84,6 @@ TEST_F(LayoutUtilTest, CopyLayoutArray) {
EXPECT_FALSE(dst.has_layout());
}
// Exercises CopyLayoutBetweenShapes when the source has a sparse layout:
// the sparse layout is copied onto a dense destination, a layout-less
// destination gains it, and a cleared source clears the destination.
TEST_F(LayoutUtilTest, CopyLayoutSparse) {
Shape src = MakeShapeWithSparseLayout(F32, {2, 3}, 2);
Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0});
EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
// Should work if destination has no layout.
dst.clear_layout();
EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
// If source is cleared, then destination should be cleared.
src.clear_layout();
EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
EXPECT_TRUE(dst.has_layout());
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
EXPECT_FALSE(dst.has_layout());
}
TEST_F(LayoutUtilTest, CopyLayoutTuple) {
Shape src = ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
@ -134,25 +103,6 @@ TEST_F(LayoutUtilTest, CopyLayoutTuple) {
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
}
// Copies layouts across matching tuple shapes where the source elements
// (including one inside a nested tuple) carry sparse layouts.
TEST_F(LayoutUtilTest, CopyLayoutTupleSparse) {
Shape src = ShapeUtil::MakeTupleShape(
{MakeShapeWithSparseLayout(F32, {2, 3}, 4),
MakeShapeWithSparseLayout(F32, {42, 123}, 4),
ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {}, {}),
MakeShapeWithSparseLayout(F32, {1, 2, 3}, 6)})});
Shape dst = ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
MakeShapeWithLayout(F32, {42, 123}, {1, 0}),
ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {}, {}),
MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})});
EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
}
TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) {
Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1});
Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0});
@ -160,13 +110,6 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) {
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
}
// Source and destination have the same rank but different dimensions; the
// copy is still expected to succeed (per the ASSERT_IS_OK below).
TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleSameRank) {
Shape src = MakeShapeWithSparseLayout(F32, {123, 42, 7}, 6);
Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0});
ASSERT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
}
TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) {
Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1});
Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0});
@ -176,15 +119,6 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) {
::testing::ContainsRegex("cannot copy layout from shape"));
}
// Copying a rank-3 layout onto a rank-2 sparse destination must fail with a
// "cannot copy layout from shape" error.
TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleDifferentRank) {
Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1});
Shape dst = MakeShapeWithSparseLayout(F32, {2, 3}, 4);
auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst);
EXPECT_FALSE(status.ok());
EXPECT_THAT(status.error_message(),
::testing::ContainsRegex("cannot copy layout from shape"));
}
TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleTuple) {
Shape src =
ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 3}, {0, 1}),

View File

@ -80,7 +80,7 @@ bool LiteralProtoHasValues(const LiteralProto& proto) {
proto.c64s_size() || proto.c128s_size() ||
proto.tuple_literals_size() || !proto.f16s().empty() ||
!proto.bf16s().empty() || !proto.u16s().empty() ||
!proto.s16s().empty() || proto.sparse_indices_size();
!proto.s16s().empty();
}
} // namespace
@ -135,21 +135,8 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
// Literals can be used as DMA targets, which can require alignment. We
// force a 16-byte minimum alignment.
constexpr int kMinimumAlignment = 16;
if (LayoutUtil::IsSparseArray(shape)) {
// For sparse arrays, the buffer must be of the size of the maximum
// number of sparse elements possible.
const int64 max_sparse_elements =
LayoutUtil::MaxSparseElements(shape.layout());
piece->set_buffer(static_cast<char*>(tensorflow::port::AlignedMalloc(
max_sparse_elements *
ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()),
kMinimumAlignment)));
piece->set_sparse_indices(
new SparseIndexArray(max_sparse_elements, shape.rank()));
} else {
piece->set_buffer(static_cast<char*>(tensorflow::port::AlignedMalloc(
piece->size_bytes(), kMinimumAlignment)));
}
piece->set_buffer(static_cast<char*>(tensorflow::port::AlignedMalloc(
piece->size_bytes(), kMinimumAlignment)));
}
} else {
// If the shape is neither an array nor tuple, then it must be
@ -181,7 +168,6 @@ void Literal::DeallocateBuffers() {
[&](const ShapeIndex& index, Piece* piece) {
if (piece->buffer() != nullptr) {
tensorflow::port::AlignedFree(piece->buffer());
delete piece->sparse_indices();
}
});
}
@ -211,16 +197,6 @@ Literal LiteralBase::CreateFromShape(const Shape& shape) {
return literal;
}
// Returns the sparse index array of the piece at `shape_index`. Per the
// accessor's declaration, this is nullptr when the literal is not sparse.
const SparseIndexArray* LiteralBase::sparse_indices(
const ShapeIndex& shape_index) const {
return piece(shape_index).sparse_indices();
}
// Mutable counterpart of LiteralBase::sparse_indices for the piece at
// `shape_index`.
SparseIndexArray* MutableLiteralBase::sparse_indices(
const ShapeIndex& shape_index) {
return piece(shape_index).sparse_indices();
}
template <typename NativeT>
Status MutableLiteralBase::CopySliceFromInternal(
const LiteralBase& src_literal, absl::Span<const int64> src_base,
@ -373,12 +349,9 @@ std::vector<Literal> Literal::DecomposeTuple() {
}
Piece& src_piece = piece(src_index);
// Move the respective buffer and sparse indices over to the element
// Literal.
// Move the respective buffer over to the element Literal.
dest_piece->set_buffer(src_piece.buffer());
src_piece.set_buffer(nullptr);
dest_piece->set_sparse_indices(src_piece.sparse_indices());
src_piece.set_sparse_indices(nullptr);
});
}
// Set this literal to be nil-shaped.
@ -512,8 +485,6 @@ Status Literal::MoveFrom(Literal&& src_literal,
Piece& dest_piece = piece(dest_index);
tensorflow::port::AlignedFree(dest_piece.buffer());
dest_piece.set_buffer(src_piece.buffer());
delete dest_piece.sparse_indices();
dest_piece.set_sparse_indices(src_piece.sparse_indices());
});
src_literal.shape_ = absl::make_unique<Shape>(ShapeUtil::MakeNil());
@ -854,66 +825,6 @@ string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
}
}
// Stringifies one sparse element, dispatching on the subshape's element
// type: PRED prints "true"/"false", complex types print "(re, im)", and the
// half-precision types are widened to float before formatting. Fatal for
// element types that cannot appear in a sparse array.
string LiteralBase::GetSparseElementAsString(
int64 sparse_element_number, const ShapeIndex& shape_index) const {
const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
CHECK(LayoutUtil::IsSparseArray(subshape));
switch (subshape.element_type()) {
case PRED:
return GetSparseElement<bool>(sparse_element_number, shape_index)
? "true"
: "false";
case S8:
return StrCat(GetSparseElement<int8>(sparse_element_number, shape_index));
case S16:
return StrCat(
GetSparseElement<int16>(sparse_element_number, shape_index));
case S32:
return StrCat(
GetSparseElement<int32>(sparse_element_number, shape_index));
case S64:
return StrCat(
GetSparseElement<int64>(sparse_element_number, shape_index));
case U8:
return StrCat(
GetSparseElement<uint8>(sparse_element_number, shape_index));
case U16:
return StrCat(
GetSparseElement<uint16>(sparse_element_number, shape_index));
case U32:
return StrCat(
GetSparseElement<uint32>(sparse_element_number, shape_index));
case U64:
return StrCat(
GetSparseElement<uint64>(sparse_element_number, shape_index));
case F16:
return StrCat(static_cast<float>(
GetSparseElement<half>(sparse_element_number, shape_index)));
case F32:
return StrCat(
GetSparseElement<float>(sparse_element_number, shape_index));
case BF16:
return StrCat(static_cast<float>(
GetSparseElement<bfloat16>(sparse_element_number, shape_index)));
case F64:
return StrCat(
GetSparseElement<double>(sparse_element_number, shape_index));
case C64: {
complex64 c =
GetSparseElement<complex64>(sparse_element_number, shape_index);
return StrCat("(", c.real(), ", ", c.imag(), ")");
}
case C128: {
complex128 c =
GetSparseElement<complex128>(sparse_element_number, shape_index);
return StrCat("(", c.real(), ", ", c.imag(), ")");
}
default:
LOG(FATAL) << "Invalid element type for sparse arrays: "
<< PrimitiveType_Name(subshape.element_type());
}
}
absl::optional<int64> LiteralBase::GetIntegralAsS64(
absl::Span<const int64> multi_index) const {
CHECK(LayoutUtil::IsDenseArray(shape()));
@ -1047,81 +958,6 @@ Status MutableLiteralBase::SetFromDouble(absl::Span<const int64> multi_index,
return Status::OK();
}
// Returns the multi-index of sparse element `sparse_element_number`,
// bounds-checked against the number of stored (index, value) pairs.
absl::Span<const int64> LiteralBase::GetSparseIndex(
int64 sparse_element_number, const ShapeIndex& shape_index) const {
const Piece& p = piece(shape_index);
CHECK_GE(sparse_element_number, 0);
CHECK_LT(sparse_element_number, p.sparse_indices()->index_count());
return p.sparse_indices()->At(sparse_element_number);
}
// Sorts the sparse elements of the piece at `shape_index`; forwards to
// Piece::SortSparseElements.
void MutableLiteralBase::SortSparseElements(const ShapeIndex& shape_index) {
piece(shape_index).SortSparseElements();
}
// Dispatches to the typed SortSparseElementsInternal<T> instantiation that
// matches this piece's element type. Element types that cannot appear in a
// sparse array are a fatal error.
void LiteralBase::Piece::SortSparseElements() {
switch (subshape().element_type()) {
case PRED:
SortSparseElementsInternal<bool>();
break;
case S8:
SortSparseElementsInternal<int8>();
break;
case U8:
SortSparseElementsInternal<uint8>();
break;
case S16:
SortSparseElementsInternal<int16>();
break;
case U16:
SortSparseElementsInternal<uint16>();
break;
case S32:
SortSparseElementsInternal<int32>();
break;
case U32:
SortSparseElementsInternal<uint32>();
break;
case S64:
SortSparseElementsInternal<int64>();
break;
case U64:
SortSparseElementsInternal<uint64>();
break;
case F32:
SortSparseElementsInternal<float>();
break;
case F64:
SortSparseElementsInternal<double>();
break;
case C64:
SortSparseElementsInternal<complex64>();
break;
case C128:
SortSparseElementsInternal<complex128>();
break;
case F16:
SortSparseElementsInternal<half>();
break;
case BF16:
SortSparseElementsInternal<bfloat16>();
break;
default:
LOG(FATAL) << "Element type not valid for sparse array: "
<< PrimitiveType_Name(subshape().element_type());
}
}
// Sorts this piece's (index, value) pairs via SparseIndexArray::SortWithValues.
// Only the first index_count() entries of the value buffer are live, so the
// span handed to SortWithValues is truncated to that length.
template <typename NativeT>
void LiteralBase::Piece::SortSparseElementsInternal() {
CHECK(LayoutUtil::IsSparseArray(subshape()));
int64 num_elements = sparse_indices()->index_count();
auto values = data<NativeT>();
CHECK_LE(num_elements, values.size());
sparse_indices()->SortWithValues(
absl::Span<NativeT>(values.data(), num_elements));
}
namespace {
string ShapeToString(bool print_layout, const Shape& shape) {
@ -1151,32 +987,6 @@ void TupleToStringHelper(const LiteralBase& literal,
pieces->push_back("\n)");
}
// Renders a sparse literal as "{index: value, ...}". Rank-1 indices print as
// a bare number; higher ranks print as a bracketed "[i, j, ...]" tuple.
void SparseArrayToStringHelper(const LiteralBase& literal,
const Shape& subshape, bool print_shape,
bool print_layout, std::vector<string>* pieces) {
if (print_shape) {
pieces->push_back(ShapeToString(print_layout, subshape));
}
pieces->push_back("{");
int64 rank = subshape.rank();
int64 num_elements = literal.sparse_element_count();
for (int64 i = 0; i < num_elements; ++i) {
if (i > 0) {
pieces->push_back(", ");
}
if (rank == 1) {
pieces->push_back(StrCat(literal.GetSparseIndex(i)[0]));
pieces->push_back(": ");
} else {
pieces->push_back("[");
pieces->push_back(absl::StrJoin(literal.GetSparseIndex(i), ", "));
pieces->push_back("]: ");
}
pieces->push_back(literal.GetSparseElementAsString(i));
}
pieces->push_back("}");
}
void DenseArrayToStringHelper(const LiteralBase& literal,
const ShapeIndex& shape_index, bool print_shape,
bool print_layout, std::vector<string>* pieces) {
@ -1261,9 +1071,6 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
pieces);
} else if (subshape.IsToken()) {
pieces->push_back("token");
} else if (LayoutUtil::IsSparseArray(subshape)) {
SparseArrayToStringHelper(literal, subshape, print_shape, print_layout,
pieces);
} else {
CHECK(LayoutUtil::IsDenseArray(subshape));
DenseArrayToStringHelper(literal, shape_index, print_shape, print_layout,
@ -1273,11 +1080,6 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
} // namespace
int64 LiteralBase::sparse_element_count() const {
  // Only sparse literals track an explicit element count: it is the number
  // of (index, value) pairs recorded in the sparse index array.
  CHECK(LayoutUtil::IsSparseArray(shape()));
  const SparseIndexArray* indices = sparse_indices();
  return indices->index_count();
}
string LiteralBase::ToString() const {
std::vector<string> pieces;
CHECK(LayoutUtil::HasLayout(this->shape()));
@ -2053,22 +1855,6 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
TF_RET_CHECK(LayoutUtil::HasLayout(shape));
TF_RET_CHECK(ShapeUtil::Equal(shape, subshape()));
if (LayoutUtil::IsSparseArray(subshape())) {
// Compute the number of elements (indices) in the sparse shape and reserve
// the necessary space in sparse_indices.
TF_RET_CHECK(subshape().rank() != 0) << "Scalar shapes cannot be sparse";
TF_RET_CHECK(proto.sparse_indices_size() % subshape().rank() == 0)
<< "Unexpected number of indices in proto ("
<< proto.sparse_indices_size() << ") for shape of rank "
<< subshape().rank();
const int64 index_count = proto.sparse_indices_size() / subshape().rank();
sparse_indices()->Resize(index_count);
// Copy the indices from the proto into the SparseIndexArray object.
TF_RETURN_IF_ERROR(CopyFromRepeatedField(sparse_indices()->mutable_data(),
proto.sparse_indices()));
}
switch (subshape().element_type()) {
case PRED:
TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
@ -2175,11 +1961,6 @@ LiteralProto LiteralBase::ToProto() const {
piece.WriteToProto(proto_piece);
});
if (LayoutUtil::IsSparseArray(shape())) {
CopyToRepeatedField(proto.mutable_sparse_indices(),
sparse_indices()->data());
}
return proto;
}
@ -2295,12 +2076,6 @@ MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr,
MutableBorrowingLiteral::~MutableBorrowingLiteral() {
if (root_piece_ != nullptr) {
root_piece_->ForEachMutableSubpiece(
[&](const ShapeIndex& index, Piece* piece) {
if (piece->buffer() != nullptr) {
delete piece->sparse_indices();
}
});
delete root_piece_;
}
}

View File

@ -35,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/sparse_index_array.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
@ -77,11 +76,6 @@ class LiteralBase {
template <typename NativeT>
absl::Span<const NativeT> data(const ShapeIndex& shape_index = {}) const;
// Returns a const pointer to the sparse index array. Returns nullptr if the
// literal is not a sparse array.
const SparseIndexArray* sparse_indices(
const ShapeIndex& shape_index = {}) const;
// Returns a const pointer to (or size of) the underlying buffer holding the
// array at the given shape index. CHECKs if the subshape of the literal at
// the given ShapeIndex is not array.
@ -126,10 +120,6 @@ class LiteralBase {
// into text.
string GetAsString(absl::Span<const int64> multi_index,
const ShapeIndex& shape_index = {}) const;
// As GetSparseElement(), but determines the correct type and converts the
// value into text.
string GetSparseElementAsString(int64 sparse_element_number,
const ShapeIndex& shape_index = {}) const;
// Return whether the value at the specified index is equal to the provided
// generic `value` (T must be an arithmetic type).
@ -172,21 +162,6 @@ class LiteralBase {
absl::optional<complex128> GetAsComplex128(
absl::Span<const int64> multi_index) const;
// Returns the multi-index of the element in a sparse literal at the given
// sparse element number. The sparse element number is the position with in
// the sparse array's list of (index, value) pairs, and is checked against the
// total number of (index, value) pairs in the sparse array.
absl::Span<const int64> GetSparseIndex(
int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;
// Returns the value of the element in a sparse literal at the given sparse
// element number. The sparse element number is the position with in the
// sparse array's list of (index, value) pairs, and is checked against the
// total number of (index, value) pairs in the sparse array.
template <typename NativeT>
NativeT GetSparseElement(int64 sparse_element_number,
const ShapeIndex& shape_index = {}) const;
// Invokes the "per cell" callback for each element in the provided
// literal with the element's indices and a string representation of
// the element's value.
@ -259,13 +234,7 @@ class LiteralBase {
return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
}
// Returns the count of the elements in the sparse array at the given shape
// index in this literal, which will be no larger than
// LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()).
int64 sparse_element_count() const;
// Compute a hash for this literal. This literal must not be a sparse tensor
// or a tuple containing a sparse tensor.
// Compute a hash for this literal.
size_t Hash() const;
// Converts this literal to the given shape. Returns an error is the
@ -385,14 +354,6 @@ class LiteralBase {
char* buffer() const { return buffer_; }
void set_buffer(char* buffer) { buffer_ = buffer; }
// The array of multi-indices that provide the locations of non-zero
// elements in a sparse array. Only used if
// LayoutUtil::IsSparseArray(shape()) is true.
SparseIndexArray* sparse_indices() const { return sparse_indices_; }
void set_sparse_indices(SparseIndexArray* sparse_indices) {
sparse_indices_ = sparse_indices;
}
// Gets or sets the subshape of this piece. This reference points to a
// subshape within the shape in the containing Literal (Literal::shape_).
const Shape& subshape() const { return *subshape_; }
@ -402,13 +363,7 @@ class LiteralBase {
int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
// Returns the number of elements in this piece's array.
int64 element_count() const {
// If this is a sparse array, use the number of elements represented by
// the indices in the associated SparseIndexArray.
return LayoutUtil::IsSparseArray(subshape())
? sparse_indices()->index_count()
: ShapeUtil::ElementsIn(subshape());
}
int64 element_count() const { return ShapeUtil::ElementsIn(subshape()); }
// Returns the child piece at 'index' of this piece.
Piece& child(int64 index) { return children_[index]; }
@ -489,9 +444,6 @@ class LiteralBase {
// piece must be equal (not just compatible) to the shape of the proto.
Status CopyFromProto(const LiteralProto& proto);
// Sorts the elements in a sparse array.
void SortSparseElements();
private:
// Helpers for traversing the piece via ForEachSubpiece rooted at 'index'.
// The first non-OK (or non-true) value is returned by the function.
@ -541,17 +493,9 @@ class LiteralBase {
bool EqualElementsInternal(const Piece& other,
std::vector<int64>* multi_index) const;
// Helper for SortSparseElements that has the element type as a template
// parameter.
template <typename NativeT>
void SortSparseElementsInternal();
// For array-shaped pieces, this is the buffer holding the literal data.
char* buffer_ = nullptr;
// For sparse arrays, this is the array of indices.
SparseIndexArray* sparse_indices_ = nullptr;
// The shape of piece. This points into the shape of the containing Literal
// (Literal::shape_).
const Shape* subshape_ = nullptr;
@ -598,10 +542,6 @@ class MutableLiteralBase : public LiteralBase {
// Unhide const method from parent class.
using LiteralBase::data;
// Returns a pointer to the sparse index array. Returns nullptr if the literal
// is not a sparse array.
SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
// TODO(b/67651157): Remove this accessor. Literal users should not be able to
// mutate the shape as this can produce malformed Literals.
Shape* mutable_shape_do_not_use() { return shape_.get(); }
@ -613,16 +553,6 @@ class MutableLiteralBase : public LiteralBase {
// Unhide const method from parent class.
using LiteralBase::untyped_data;
// Populates a literal with a sparse layout with the given indices and values.
// Each index in the indices array is CHECKed against the dimensions in the
// literal's shape. If sort is true, then the indices and values will be
// sorted. If sort is false, then the indices and values are assumed to
// already be in sorted order. See CreateSparse for an example of how data
// are populated.
template <typename NativeT>
void PopulateSparse(SparseIndexArray indices,
absl::Span<const NativeT> values, bool sort = true);
// Copy values from 'src_literal' rooted at 'src_shape_index' into this
// literal rooted at 'dest_shape_index'. The subshape of this literal rooted
// at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
@ -661,16 +591,6 @@ class MutableLiteralBase : public LiteralBase {
template <typename NativeT>
void Set(absl::Span<const int64> multi_index, NativeT value);
// Appends the given element to the literal. If the elements are not appended
// in sorted order, then SortSparseElements should be called before calling
// other methods. This literal must have a sparse layout.
template <typename NativeT>
void AppendSparseElement(absl::Span<const int64> multi_index, NativeT value,
const ShapeIndex& shape_index = {});
// Sorts the elements in a sparse array.
void SortSparseElements(const ShapeIndex& shape_index = {});
// As Set(), but truncates `value` to the literal element type before storing.
// This literal must be an array.
Status SetIntegralAsS64(absl::Span<const int64> multi_index, int64 value);
@ -988,34 +908,6 @@ NativeT LiteralBase::GetFirstElement() const {
return data<NativeT>().at(0);
}
// Returns the value stored for sparse element `sparse_element_number`.
// CHECKs that the addressed subshape actually has a sparse layout.
template <typename NativeT>
NativeT LiteralBase::GetSparseElement(int64 sparse_element_number,
const ShapeIndex& shape_index) const {
CHECK(
LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index)));
return data<NativeT>(shape_index)[sparse_element_number];
}
// Appends (multi_index, value) as the next sparse element of the piece at
// `shape_index`. Each index coordinate is CHECKed against the subshape's
// dimensions, and the element count must stay strictly below the layout's
// max_sparse_elements cap.
template <typename NativeT>
void MutableLiteralBase::AppendSparseElement(
absl::Span<const int64> multi_index, NativeT value,
const ShapeIndex& shape_index) {
Piece& p = piece(shape_index);
const Shape& subshape = p.subshape();
CHECK(LayoutUtil::IsSparseArray(subshape));
int64 rank = subshape.rank();
CHECK_EQ(multi_index.size(), rank);
for (int64 i = 0; i < rank; ++i) {
CHECK_GE(multi_index[i], 0);
CHECK_LT(multi_index[i], subshape.dimensions(i));
}
int64 last_element = p.sparse_indices()->index_count();
CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout()));
p.sparse_indices()->Append(multi_index);
CHECK_LT(last_element, p.data<NativeT>().size());
p.data<NativeT>()[last_element] = value;
}
template <typename NativeT>
void LiteralBase::EachCell(
std::function<void(absl::Span<const int64> indices, NativeT value)>
@ -1094,31 +986,6 @@ void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
PopulateFromArray(values);
}
// Fills this sparse literal from `indices` and the parallel `values` span,
// optionally sorting the (index, value) pairs into canonical index order.
// Sizes are CHECKed against the layout's max_sparse_elements cap.
template <typename NativeT>
void MutableLiteralBase::PopulateSparse(SparseIndexArray indices,
absl::Span<const NativeT> values,
bool sort) {
CHECK(LayoutUtil::IsSparseArray(shape()));
int rank = shape().rank();
CHECK_EQ(indices.rank(), rank);
int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout());
CHECK_LE(indices.max_indices(), max_elements);
int64 num_elements = values.size();
CHECK_LE(num_elements, max_elements);
CHECK_EQ(num_elements, indices.index_count());
auto root_data = root_piece().data<NativeT>();
// Piece::data() returns a Span of size equal to the number of indices
// in the SparseIndexArray. So there is no need to adjust the size of the data
// here. It is enough to just copy the incoming values into the data buffer.
std::copy(values.begin(), values.end(), root_data.begin());
*this->root_piece().sparse_indices() = std::move(indices);
if (sort) {
// Re-fetch the data span after the indices move, then co-sort values
// with their indices.
auto root_data = this->root_piece().data<NativeT>();
this->root_piece().sparse_indices()->SortWithValues(root_data);
}
DCHECK(this->root_piece().sparse_indices()->Validate(shape()));
}
template <typename NativeT, typename FnType>
Status MutableLiteralBase::PopulateInternal(const FnType& generator,
bool parallel) {

View File

@ -252,42 +252,6 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
EXPECT_EQ(expected, result);
}
// End-to-end check of LiteralUtil::CreateSparse: on creation the
// (index, value) pairs are sorted by index, and both indices and values
// survive a ToProto/CreateFromProto round trip unchanged.
TEST_F(LiteralUtilTest, CreateSparse) {
  std::vector<int64> dimensions = {8, 8, 8};
  // Deliberately unsorted input indices; CreateSparse sorts by default.
  Array2D<int64> indices = {
      {3, 4, 5},
      {1, 2, 3},
      {2, 3, 4},
      {3, 5, 6},
  };
  std::vector<int64> values = {7, 8, 9, 10};
  auto literal = LiteralUtil::CreateSparse<int64>(
      dimensions, SparseIndexArray(indices.n1() + 3, indices), values);
  // Expected order after sorting by index; values are permuted accordingly.
  Array2D<int64> expected_indices = {
      {1, 2, 3},
      {2, 3, 4},
      {3, 4, 5},
      {3, 5, 6},
  };
  std::vector<int64> expected_values = {8, 9, 7, 10};
  EXPECT_EQ(literal.sparse_indices()->data(),
            absl::Span<const int64>(expected_indices.data(),
                                    expected_indices.num_elements()));
  EXPECT_EQ(literal.data<int64>(), absl::Span<const int64>(expected_values));
  // Serialize then deserialize and verify the resulting literal.
  TF_ASSERT_OK_AND_ASSIGN(Literal literal_from_proto,
                          Literal::CreateFromProto(literal.ToProto()));
  EXPECT_EQ(literal_from_proto.sparse_indices()->data(),
            absl::Span<const int64>(expected_indices.data(),
                                    expected_indices.num_elements()));
  EXPECT_EQ(literal_from_proto.data<int64>(),
            absl::Span<const int64>(expected_values));
}
TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
// clang-format off
auto literal = LiteralUtil::CreateR4Projected<float>({
@ -1978,43 +1942,6 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements"));
}
// SortSparseElements must reorder the (index, value) pairs lexicographically
// by index, regardless of insertion order.
TEST_F(LiteralUtilTest, SortSparseElements) {
  auto unsorted = LiteralUtil::CreateSparse<float>({10, 10, 10},
                                                   SparseIndexArray(10, 3), {});
  // Append elements deliberately out of index order.
  unsorted.AppendSparseElement<float>({2, 3, 4}, 2.0);
  unsorted.AppendSparseElement<float>({3, 4, 5}, 3.0);
  unsorted.AppendSparseElement<float>({1, 2, 3}, 1.0);
  unsorted.SortSparseElements();
  EXPECT_EQ(unsorted.ToString(),
            "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}");
}
// GetSparseElementAsString(i) should stringify the value of the i-th sparse
// element; for numeric types the expected text matches absl::StrCat on the
// same value.
TEST_F(LiteralUtilTest, GetSparseElementAsString) {
  std::vector<int64> dimensions = {10, 10, 10};
  SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}});
  // Element 1 (the middle entry) is queried for each primitive type below.
  EXPECT_EQ(
      LiteralUtil::CreateSparse<bool>(dimensions, indices, {true, false, true})
          .GetSparseElementAsString(1),
      "false");
  EXPECT_EQ(LiteralUtil::CreateSparse<int64>(dimensions, indices, {1, 2, 3})
                .GetSparseElementAsString(1),
            absl::StrCat(int64{2}));
  EXPECT_EQ(
      LiteralUtil::CreateSparse<double>(dimensions, indices, {1.0, 2.0, 3.0})
          .GetSparseElementAsString(1),
      absl::StrCat(double{2.0}));
  // half is printed via its float conversion.
  EXPECT_EQ(LiteralUtil::CreateSparse<half>(dimensions, indices,
                                            {half{1.0}, half{2.0}, half{3.0}})
                .GetSparseElementAsString(1),
            absl::StrCat(static_cast<float>(half{2.0})));
  // Complex values are printed as "(real, imag)".
  EXPECT_EQ(LiteralUtil::CreateSparse<complex64>(
                dimensions, indices,
                std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})
                .GetSparseElementAsString(1),
            absl::StrCat("(", float{3.0}, ", ", float{4.0}, ")"));
}
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) {
Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
TF_ASSERT_OK_AND_ASSIGN(

View File

@ -38,7 +38,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/sparse_index_array.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
@ -102,46 +101,6 @@ class LiteralUtil {
values,
const Layout& layout);
// Creates a literal with a sparse layout and the given indices and values.
// The shape is initialized from the given dimensions. The minor dimension of
// the indices array must equal the rank of the shape (i.e. size of the
// dimensions array). The major dimension of the indices array must equal the
// number of elements in the values array. The maximum number of elements in
// the array is taken from the max_indices() value of the index array.
//
// XLA assumes that sparse literals are in sorted order for all operations. If
// the `sort` argument is true, then the indices and values will be sorted
// while copying them into the literal. If you have ensured that the indices
// and values are already sorted, then you may set the `sort` argument to
// false to skip the sorting step.
//
// For example:
//
// CreateSparse(
// {12, 12, 12},
// SparseIndexArray(10, 3,
// Array2D{
// {0, 1, 2},
// {3, 4, 5},
// {6, 7, 8},
// {9, 10, 11},
// }),
// {1.0, 2.0 3.0, 4.0})
//
// This creates an array with shape F64[12,12,12]sparse{10}, that has the
// following non-zero values:
//
// [0, 1, 2]: 1.0
// [3, 4, 5]: 2.0
// [6, 7, 8]: 3.0
// [9, 10, 11]: 4.0
//
template <typename NativeT>
static Literal CreateSparse(absl::Span<const int64> dimensions,
SparseIndexArray indices,
absl::Span<const NativeT> values,
bool sort = true);
// Creates a scalar literal value zero of the given primitive type.
static Literal Zero(PrimitiveType primitive_type);
// Creates a scalar literal value one of the given primitive type.
@ -417,21 +376,6 @@ template <typename NativeT>
return CreateR4FromArray4DWithLayout(tmp, layout);
}
// Builds a sparse literal of the given dimensions from an index array and a
// parallel span of values. The index array's rank must equal the number of
// dimensions and its index count must equal the number of values; the
// literal's capacity comes from indices.max_indices().
template <typename NativeT>
/* static */ Literal LiteralUtil::CreateSparse(
    absl::Span<const int64> dimensions, SparseIndexArray indices,
    absl::Span<const NativeT> values, bool sort) {
  const int64 value_count = values.size();
  const int64 index_rank = dimensions.size();
  CHECK_EQ(index_rank, indices.rank());
  CHECK_EQ(value_count, indices.index_count());
  Literal result(ShapeUtil::MakeShapeWithSparseLayout(
      primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
      indices.max_indices()));
  result.PopulateSparse(indices, values, sort);
  return result;
}
template <typename NativeT>
/* static */ Literal LiteralUtil::CreateR4(
std::initializer_list<std::initializer_list<

View File

@ -77,7 +77,6 @@ cc_library(
":buffer_info_util",
":conv_canonicalization",
":cpu_executable",
":cpu_hlo_support_checker",
":cpu_instruction_fusion",
":cpu_layout_assignment",
":cpu_options",
@ -960,32 +959,6 @@ cc_library(
],
)
cc_library(
name = "cpu_hlo_support_checker",
srcs = ["cpu_hlo_support_checker.cc"],
hdrs = ["cpu_hlo_support_checker.h"],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "cpu_hlo_support_checker_test",
srcs = ["cpu_hlo_support_checker_test.cc"],
deps = [
":cpu_hlo_support_checker",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
],
)
tf_cc_test(
name = "cpu_eigen_tensor_alignment_test",
size = "small",

View File

@ -60,7 +60,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
@ -248,7 +247,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
pipeline.AddPass<ZeroSizedHloElimination>();
pipeline.AddPass<DynamicIndexSplitter>();
pipeline.AddPass<CpuHloSupportChecker>();
pipeline.AddPass<ConditionalToSelect>();
pipeline.AddPass<MapInliner>();

View File

@ -1,46 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
// Validates every instruction shape in the module and rejects any shape that
// contains a sparse layout, which the CPU backend cannot execute. Never
// mutates the module, so always returns false on success.
StatusOr<bool> CpuHloSupportChecker::Run(HloModule* module) {
  // Returns Unimplemented if any subshape of `instruction` is sparse.
  auto reject_sparse = [](const HloInstruction* instruction) {
    return ShapeUtil::ForEachSubshapeWithStatus(
        instruction->shape(),
        [instruction](const Shape& subshape, const ShapeIndex&) {
          if (!LayoutUtil::IsSparseArray(subshape)) {
            return Status::OK();
          }
          return xla::Unimplemented(
              "CPU backend does not support HLO instruction %s with shape "
              "containing a sparse layout: %s",
              instruction->ToString(),
              ShapeUtil::HumanStringWithLayout(instruction->shape()));
        });
  };
  for (auto* computation : module->computations()) {
    for (const auto& instruction : computation->instructions()) {
      TF_RETURN_IF_ERROR(
          ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape()));
      TF_RETURN_IF_ERROR(reject_sparse(instruction));
    }
  }
  return false;
}
} // namespace xla

View File

@ -1,40 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_

#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"

namespace xla {

// This pass should run early in the HLO pipeline and checks for HLO constructs
// which are not supported by the CPU backend and cannot be removed via HLO
// transformations (eg, sparse layouts).
class CpuHloSupportChecker : public HloModulePass {
 public:
  CpuHloSupportChecker() = default;
  ~CpuHloSupportChecker() override = default;

  absl::string_view name() const override { return "cpu_hlo_support_checker"; }

  // Note: always returns false (no instructions are ever modified by this
  // pass). Returns an error status if an unsupported construct (such as a
  // sparse layout) or an invalid shape is found in the module.
  StatusOr<bool> Run(HloModule* module) override;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_

View File

@ -1,76 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
namespace xla {
namespace {
using ::testing::HasSubstr;
// Test fixture owning a reusable CpuHloSupportChecker instance for the tests
// below.
class CpuHloSupportCheckerTest : public HloTestBase {
 protected:
  // Accessor for the pass under test.
  CpuHloSupportChecker& checker() { return checker_; }

 private:
  CpuHloSupportChecker checker_;
};
// A plain dense scalar add is supported: the checker must succeed without
// reporting an error.
TEST_F(CpuHloSupportCheckerTest, Add) {
  HloComputation::Builder builder(TestName());
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param0"));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "param1"));
  builder.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kAdd, lhs, rhs));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  TF_ASSERT_OK(checker().Run(module.get()).status());
}
// An instruction whose shape carries a sparse layout must be rejected by the
// checker with an UNIMPLEMENTED status whose message names the backend and
// the offending shape.
TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
  HloComputation::Builder builder(TestName());
  const Shape sparse_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {10}, 2);
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, sparse_shape, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, sparse_shape, "param1"));
  builder.AddInstruction(HloInstruction::CreateBinary(
      sparse_shape, HloOpcode::kAdd, param0, param1));
  // Since verifier is reporting sparse layouts as errors, we should
  // use a regular HloModule instead of VerifiedHloModule to avoid
  // verifier errors being triggered in the destructor.
  auto module = CreateNewUnverifiedModule();
  module->AddEntryComputation(builder.Build());
  Status status = checker().Run(module.get()).status();
  ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
  // The error text should identify both the backend and the sparse shape.
  EXPECT_THAT(status.error_message(),
              HasSubstr("CPU backend does not support"));
  EXPECT_THAT(status.error_message(),
              HasSubstr(ShapeUtil::HumanStringWithLayout(sparse_shape)));
}
} // namespace
} // namespace xla

View File

@ -116,29 +116,6 @@ non_tuple
| rank2345
;
rank2345
  : shape sparse_or_nested_array
  | nested_array
  ;
sparse_or_nested_array
: sparse_array
| nested_array
;
sparse_array
: '{' sparse_array1 '}'
;
sparse_array1
: sparse_array_item
| sparse_array1 ',' sparse_array_item
;
sparse_array_item
: multi_index ':' scalar
;
multi_index
: kInt
| '[' multi_index1 ']'
;
multi_index1
: kInt
| multi_index1 ',' kInt
;
```

View File

@ -1093,7 +1093,6 @@ cc_library(
":gpu_copy_insertion",
":gpu_executable",
":gpu_hlo_schedule",
":gpu_hlo_support_checker",
":gpu_layout_assignment",
":gpu_sanitize_constant_names",
":gpu_scatter_expander",
@ -1416,18 +1415,6 @@ tf_cc_test(
],
)
cc_library(
name = "gpu_hlo_support_checker",
srcs = ["gpu_hlo_support_checker.cc"],
hdrs = ["gpu_hlo_support_checker.h"],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
],
)
cc_library(
name = "stream_executor_util",
srcs = ["stream_executor_util.cc"],
@ -1455,20 +1442,6 @@ cc_library(
],
)
tf_cc_test(
name = "gpu_hlo_support_checker_test",
srcs = ["gpu_hlo_support_checker_test.cc"],
deps = [
":gpu_hlo_support_checker",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
],
)
cc_library(
name = "buffer_comparator",
srcs = ["buffer_comparator.cc"],

View File

@ -49,7 +49,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h"
@ -135,7 +134,6 @@ Status GpuCompiler::OptimizeHloModule(
pipeline.AddPass<GpuScatterExpander>();
pipeline.AddPass<DynamicIndexSplitter>();
pipeline.AddPass<GpuHloSupportChecker>();
// TODO(b/64094172): make Call work on GPU instead of inlining.
pipeline.AddPass<CallInliner>();

View File

@ -1,46 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
// Validates every instruction shape in the module and rejects any shape that
// contains a sparse layout, which the GPU backend cannot execute. Never
// mutates the module, so always returns false on success.
StatusOr<bool> GpuHloSupportChecker::Run(HloModule* module) {
  // Returns Unimplemented if any subshape of `instruction` is sparse.
  auto reject_sparse = [](const HloInstruction* instruction) {
    return ShapeUtil::ForEachSubshapeWithStatus(
        instruction->shape(),
        [instruction](const Shape& subshape, const ShapeIndex&) {
          if (!LayoutUtil::IsSparseArray(subshape)) {
            return Status::OK();
          }
          return xla::Unimplemented(
              "GPU backend does not support HLO instruction %s with shape "
              "containing a sparse layout: %s",
              instruction->ToString(),
              ShapeUtil::HumanStringWithLayout(instruction->shape()));
        });
  };
  for (auto* computation : module->computations()) {
    for (const auto& instruction : computation->instructions()) {
      TF_RETURN_IF_ERROR(
          ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape()));
      TF_RETURN_IF_ERROR(reject_sparse(instruction));
    }
  }
  return false;
}
} // namespace xla

View File

@ -1,40 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_

#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"

namespace xla {

// This pass should run early in the HLO pipeline and checks for HLO constructs
// which are not supported by the GPU backend and cannot be removed via HLO
// transformations (eg, sparse layouts).
class GpuHloSupportChecker : public HloModulePass {
 public:
  GpuHloSupportChecker() = default;
  ~GpuHloSupportChecker() override = default;

  absl::string_view name() const override { return "gpu_hlo_support_checker"; }

  // Note: always returns false (no instructions are ever modified by this
  // pass). Returns an error status if an unsupported construct (such as a
  // sparse layout) or an invalid shape is found in the module.
  StatusOr<bool> Run(HloModule* module) override;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_

View File

@ -1,76 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
namespace xla {
namespace {
using ::testing::HasSubstr;
// Test fixture owning a reusable GpuHloSupportChecker instance for the tests
// below.
class GpuHloSupportCheckerTest : public HloTestBase {
 protected:
  // Accessor for the pass under test.
  GpuHloSupportChecker& checker() { return checker_; }

 private:
  GpuHloSupportChecker checker_;
};
// A plain dense scalar add is supported: the checker must succeed without
// reporting an error.
TEST_F(GpuHloSupportCheckerTest, Add) {
  HloComputation::Builder builder(TestName());
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param0"));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "param1"));
  builder.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kAdd, lhs, rhs));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  TF_ASSERT_OK(checker().Run(module.get()).status());
}
// An instruction whose shape carries a sparse layout must be rejected by the
// checker with an UNIMPLEMENTED status whose message names the backend and
// the offending shape.
TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
  HloComputation::Builder builder(TestName());
  const Shape sparse_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {10}, 2);
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, sparse_shape, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, sparse_shape, "param1"));
  builder.AddInstruction(HloInstruction::CreateBinary(
      sparse_shape, HloOpcode::kAdd, param0, param1));
  // Since verifier is reporting sparse layouts as errors, we should
  // use a regular HloModule instead of VerifiedHloModule to avoid
  // verifier errors being triggered in the destructor.
  auto module = CreateNewUnverifiedModule();
  module->AddEntryComputation(builder.Build());
  Status status = checker().Run(module.get()).status();
  ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
  // The error text should identify both the backend and the sparse shape.
  EXPECT_THAT(status.error_message(),
              HasSubstr("GPU backend does not support"));
  EXPECT_THAT(status.error_message(),
              HasSubstr(ShapeUtil::HumanStringWithLayout(sparse_shape)));
}
} // namespace
} // namespace xla

View File

@ -280,7 +280,6 @@ TokKind HloLexer::LexIdentifier() {
KEYWORD(ROOT);
KEYWORD(maximal);
KEYWORD(replicated);
KEYWORD(sparse);
#undef KEYWORD
@ -496,8 +495,6 @@ string TokKindToString(TokKind kind) {
return "kw_inf";
case TokKind::kNegInf:
return "kNegInf";
case TokKind::kw_sparse:
return "kw_sparse";
case TokKind::kPrimitiveType:
return "kPrimitiveType";
case TokKind::kName:

View File

@ -63,7 +63,6 @@ enum class TokKind {
kw_replicated,
kw_nan,
kw_inf,
kw_sparse,
kNegInf, // -inf

View File

@ -72,10 +72,6 @@ HloSchedule ScheduleFromInstructionOrder(HloModule* module) {
return schedule;
}
// Some functions accept either a linear index or a multi-dimensional index
// (used for indexing into sparse literals).
using LinearOrMultiIndex = absl::variant<int64, absl::Span<const int64>>;
// Parser for the HloModule::ToString() format text.
class HloParserImpl : public HloParser {
public:
@ -137,24 +133,21 @@ class HloParserImpl : public HloParser {
bool ParseTupleLiteral(Literal* literal, const Shape& shape);
bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
bool ParseDenseLiteral(Literal* literal, const Shape& shape);
bool ParseSparseLiteral(Literal* literal, const Shape& shape);
// Sets the sub-value of literal at the given linear or sparse index to the
// given value. If the literal is dense, it myst have the default layout.
// Sets the sub-value of literal at the given linear index to the
// given value. If the literal is dense, it must have the default layout.
//
// `loc` should be the source location of the value.
bool SetValueInLiteral(LocTy loc, int64 value, LinearOrMultiIndex index,
bool SetValueInLiteral(LocTy loc, int64 value, int64 index, Literal* literal);
bool SetValueInLiteral(LocTy loc, double value, int64 index,
Literal* literal);
bool SetValueInLiteral(LocTy loc, double value, LinearOrMultiIndex index,
bool SetValueInLiteral(LocTy loc, bool value, int64 index, Literal* literal);
bool SetValueInLiteral(LocTy loc, std::complex<double> value, int64 index,
Literal* literal);
bool SetValueInLiteral(LocTy loc, bool value, LinearOrMultiIndex index,
Literal* literal);
bool SetValueInLiteral(LocTy loc, std::complex<double> value,
LinearOrMultiIndex index, Literal* literal);
// `loc` should be the source location of the value.
template <typename LiteralNativeT, typename ParsedElemT>
bool SetValueInLiteralHelper(LocTy loc, ParsedElemT value,
LinearOrMultiIndex index, Literal* literal);
bool SetValueInLiteralHelper(LocTy loc, ParsedElemT value, int64 index,
Literal* literal);
// Checks whether the given value is within the range of LiteralNativeT.
// `loc` should be the source location of the value.
@ -2125,8 +2118,7 @@ bool HloParserImpl::ParseInstructionNames(
"expects '}' at the end of instruction name list");
}
bool HloParserImpl::SetValueInLiteral(LocTy loc, int64 value,
LinearOrMultiIndex index,
bool HloParserImpl::SetValueInLiteral(LocTy loc, int64 value, int64 index,
Literal* literal) {
const Shape& shape = literal->shape();
switch (shape.element_type()) {
@ -2160,8 +2152,7 @@ bool HloParserImpl::SetValueInLiteral(LocTy loc, int64 value,
}
}
bool HloParserImpl::SetValueInLiteral(LocTy loc, double value,
LinearOrMultiIndex index,
bool HloParserImpl::SetValueInLiteral(LocTy loc, double value, int64 index,
Literal* literal) {
const Shape& shape = literal->shape();
switch (shape.element_type()) {
@ -2180,8 +2171,7 @@ bool HloParserImpl::SetValueInLiteral(LocTy loc, double value,
}
}
bool HloParserImpl::SetValueInLiteral(LocTy loc, bool value,
LinearOrMultiIndex index,
bool HloParserImpl::SetValueInLiteral(LocTy loc, bool value, int64 index,
Literal* literal) {
const Shape& shape = literal->shape();
switch (shape.element_type()) {
@ -2194,8 +2184,7 @@ bool HloParserImpl::SetValueInLiteral(LocTy loc, bool value,
}
bool HloParserImpl::SetValueInLiteral(LocTy loc, std::complex<double> value,
LinearOrMultiIndex index,
Literal* literal) {
int64 index, Literal* literal) {
const Shape& shape = literal->shape();
switch (shape.element_type()) {
case C64:
@ -2221,54 +2210,21 @@ std::string StringifyValue(std::complex<double> val) {
template <typename LiteralNativeT, typename ParsedElemT>
bool HloParserImpl::SetValueInLiteralHelper(LocTy loc, ParsedElemT value,
LinearOrMultiIndex index,
Literal* literal) {
int64 index, Literal* literal) {
if (!CheckParsedValueIsInRange<LiteralNativeT>(loc, value)) {
return false;
}
// Check that the index is in range and assign into the literal
if (auto* linear_index = absl::get_if<int64>(&index)) {
if (*linear_index >= ShapeUtil::ElementsIn(literal->shape())) {
return Error(loc, StrCat("trys to set value ", StringifyValue(value),
" to a literal in shape ",
ShapeUtil::HumanString(literal->shape()),
" at linear index ", *linear_index,
", but the index is out of range"));
}
literal->data<LiteralNativeT>().at(*linear_index) =
static_cast<LiteralNativeT>(value);
} else {
auto* multi_index = absl::get_if<absl::Span<const int64>>(&index);
CHECK(multi_index != nullptr);
auto invalid_idx = [&](std::string msg) {
return Error(loc, StrFormat("Invalid sparse index [%s]. %s",
absl::StrJoin(*multi_index, ", "), msg));
};
const auto& shape = literal->shape();
if (shape.rank() != multi_index->size()) {
return invalid_idx(
StrFormat("Has rank %d, but constant has shape %s, which has rank %d",
multi_index->size(), shape.ToString(), shape.rank()));
}
for (int64 i = 0; i < shape.rank(); ++i) {
auto idx = (*multi_index)[i];
if (idx < 0) {
return invalid_idx(StrFormat(
"Sub-index value at %d, namely %d, cannot be negative.", i, idx));
}
if (idx >= shape.dimensions(i)) {
return invalid_idx(
StrFormat("Sub-index at %d, namely %d, doesn't fit within shape "
"dimension %d in %s",
i, idx, shape.dimensions(i), shape.ToString()));
}
}
literal->AppendSparseElement(*multi_index,
static_cast<LiteralNativeT>(value));
if (index >= ShapeUtil::ElementsIn(literal->shape())) {
return Error(loc, StrCat("trys to set value ", StringifyValue(value),
" to a literal in shape ",
ShapeUtil::HumanString(literal->shape()),
" at linear index ", index,
", but the index is out of range"));
}
literal->data<LiteralNativeT>().at(index) =
static_cast<LiteralNativeT>(value);
return true;
}
@ -2314,12 +2270,8 @@ bool HloParserImpl::ParseTupleLiteral(Literal* literal, const Shape& shape) {
// non_tuple
// ::= rank01
// ::= rank2345
// rank2345 ::= shape sparse_or_nested_array
// rank2345 ::= shape nested_array
bool HloParserImpl::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
if (LayoutUtil::IsSparseArray(shape)) {
return ParseSparseLiteral(literal, shape);
}
CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ToString(true);
return ParseDenseLiteral(literal, shape);
}
@ -2500,98 +2452,6 @@ bool HloParserImpl::ParseDenseLiteral(Literal* literal, const Shape& shape) {
return true;
}
// Parses a sparse literal of the given sparse-layout shape:
//   '{' (index ':' value (',' index ':' value)*)? '}'
// where index is either a single integer (rank-1) or a bracketed integer
// list, and value is parsed according to the shape's element type. Elements
// are sorted by index after parsing.
bool HloParserImpl::ParseSparseLiteral(Literal* literal, const Shape& shape) {
  *literal = Literal(shape);
  if (!ParseToken(TokKind::kLbrace,
                  "expects '{' at the beginning of a sparse literal")) {
    return false;
  }

  for (;;) {
    if (lexer_.GetKind() == TokKind::kRbrace) {
      lexer_.Lex();
      break;
    }

    // Reject the element up front if the layout's capacity is already full;
    // otherwise AppendSparseElement would CHECK-fail. (The previous version
    // performed this check after the append, which both fired one element too
    // early and could not prevent the CHECK crash at exactly the maximum.)
    if (literal->sparse_element_count() ==
        LayoutUtil::MaxSparseElements(shape.layout())) {
      return Error(
          lexer_.GetLoc(),
          StrCat("number of sparse elements exceeds maximum for layout: ",
                 ShapeUtil::HumanStringWithLayout(shape)));
    }

    std::vector<int64> index;
    if (lexer_.GetKind() == TokKind::kInt) {
      // A bare integer is shorthand for a rank-1 index.
      int64 single_index = lexer_.GetInt64Val();
      lexer_.Lex();
      index.push_back(single_index);
    } else {
      if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma,
                          &index)) {
        return false;
      }
    }
    if (!ParseToken(TokKind::kColon,
                    "expects ':' after the sparse array index and before "
                    "the sparse array value")) {
      return false;
    }

    LocTy value_loc = lexer_.GetLoc();
    // Dispatch on the element type to parse and store the value.
    if (lexer_.GetKind() == TokKind::kw_true ||
        lexer_.GetKind() == TokKind::kw_false) {
      bool value = lexer_.GetKind() == TokKind::kw_true;
      if (!SetValueInLiteral(lexer_.GetLoc(), value, index, literal)) {
        return false;
      }
      lexer_.Lex();
    } else if (primitive_util::IsIntegralType(shape.element_type())) {
      int64 value;
      if (!ParseInt64(&value)) {
        return Error(value_loc,
                     StrCat("expects integer for primitive type: ",
                            PrimitiveType_Name(shape.element_type())));
      }
      if (!SetValueInLiteral(value_loc, value, index, literal)) {
        return false;
      }
    } else if (primitive_util::IsFloatingPointType(shape.element_type())) {
      double value;
      if (!ParseDouble(&value)) {
        return Error(value_loc,
                     StrCat("expects floating point value for primitive type: ",
                            PrimitiveType_Name(shape.element_type())));
      }
      if (!SetValueInLiteral(value_loc, value, index, literal)) {
        return false;
      }
    } else if (primitive_util::IsComplexType(shape.element_type())) {
      std::complex<double> value;
      if (!ParseComplex(&value)) {
        return Error(value_loc,
                     StrCat("expects complex value for primitive type: ",
                            PrimitiveType_Name(shape.element_type())));
      }
      if (!SetValueInLiteral(value_loc, value, index, literal)) {
        return false;
      }
    } else {
      LOG(FATAL) << "Unexpected element type: "
                 << PrimitiveType_Name(shape.element_type());
    }

    if (lexer_.GetKind() != TokKind::kRbrace &&
        !ParseToken(TokKind::kComma,
                    "expects ',' separator between sparse array elements")) {
      return false;
    }
  }

  literal->SortSparseElements();
  return true;
}
// MaxFiniteValue is a type-traits helper used by
// HloParserImpl::CheckParsedValueIsInRange.
template <typename T>
@ -3839,21 +3699,6 @@ bool HloParserImpl::ParseShape(Shape* result) {
}
LayoutUtil::SetToDefaultLayout(result);
if (lexer_.GetKind() == TokKind::kw_sparse) {
lexer_.Lex();
const std::string message =
"expects a brace-bracketed integer for sparse layout";
int64 max_sparse_elements;
if (!ParseToken(TokKind::kLbrace, message) ||
!ParseInt64(&max_sparse_elements) ||
!ParseToken(TokKind::kRbrace, message)) {
return false;
}
*result->mutable_layout() =
LayoutUtil::MakeSparseLayout(max_sparse_elements);
return true;
}
// We need to lookahead to see if a following open brace is the start of a
// layout. The specific problematic case is:
//

View File

@ -841,50 +841,6 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] {
)"
},
{
"Sparse",
R"(HloModule sparse_f32
ENTRY %sparse () -> f32[2,3,4] {
ROOT %foo = f32[2,3,4]sparse{10} constant({[0, 1, 2]: 1, [1, 2, 2]: 2, [1, 2, 3]: 3})
}
)",
/*enable_verification=*/false
},
{
"SparseC128",
R"(HloModule sparse_c128
ENTRY %sparse () -> c128[2,3,4] {
ROOT %foo = c128[2,3,4]sparse{10} constant({[0, 1, 2]: (1, 0), [1, 2, 2]: (2, 5), [1, 2, 3]: (3, 10)})
}
)",
/*enable_verification=*/false
},
{
"SparseEmpty",
R"(HloModule sparse_f32_empty
ENTRY %sparse_f32_empty () -> f32[2,3,4] {
ROOT %foo = f32[2,3,4]sparse{10} constant({})
}
)",
/*enable_verification=*/false,
},
{
"SparseR1",
R"(HloModule sparse_f32_r1
ENTRY %sparse_f32_r1 () -> f32[9] {
ROOT %foo = f32[9]sparse{10} constant({1: 2, 3: 4, 5: 6})
}
)",
/*enable_verification=*/false,
},
{
"Gather",
R"(HloModule StringifyGather
@ -1982,17 +1938,6 @@ TEST_F(HloParserTest, ConstantBf16Overflow) {
"out of range");
}
TEST_F(HloParserTest, ConstantF16OverflowInSparseArray) {
const string original = R"(
HloModule test_module
ENTRY test {
ROOT c = f16[5]sparse{10} constant({[0]: 0, [1]: -65520})
})";
ExpectHasSubstr(
ParseAndReturnUnverifiedModule(original).status().error_message(),
"is out of range for literal's primitive type F16");
}
TEST_F(HloParserTest, ConstantUnsignedUnderflow) {
const string original = R"(
HloModule ConstantUnsignedUnderflow_module
@ -2852,50 +2797,6 @@ ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] {
" with the shape of the operand instruction f32[2,2]{1,0}.");
}
TEST_F(HloParserTest, OutOfRangeSparseIndex) {
const string original = R"(
HloModule test_module
ENTRY test {
ROOT c = f16[5]sparse{10} constant({[100]: 0})
})";
ExpectHasSubstr(
ParseAndReturnUnverifiedModule(original).status().error_message(),
"Invalid sparse index");
}
TEST_F(HloParserTest, NegativeSparseIndex) {
const string original = R"(
HloModule test_module
ENTRY test {
ROOT c = f16[5]sparse{10} constant({-1: 0})
})";
ExpectHasSubstr(
ParseAndReturnUnverifiedModule(original).status().error_message(),
"Invalid sparse index");
}
TEST_F(HloParserTest, SparseIndexWithRankTooLarge) {
const string original = R"(
HloModule test_module
ENTRY test {
ROOT c = f16[5]sparse{10} constant({[0, 0]: 0})
})";
ExpectHasSubstr(
ParseAndReturnUnverifiedModule(original).status().error_message(),
"Invalid sparse index");
}
TEST_F(HloParserTest, SparseIndexWithRankTooSmall) {
const string original = R"(
HloModule test_module
ENTRY test {
ROOT c = f16[5, 5]sparse{10} constant({[0]: 0})
})";
ExpectHasSubstr(
ParseAndReturnUnverifiedModule(original).status().error_message(),
"Invalid sparse index");
}
TEST_F(HloParserTest, ParseShapeStringR2F32) {
string shape_string = "f32[123,456]";
TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
@ -2994,15 +2895,6 @@ TEST_F(HloParserTest, ParseShapeStringWithTilingLayout) {
"Dimensions size is 3, but minor to major size is 1.");
}
TEST_F(HloParserTest, ParseShapeStringWithSparseLayout) {
string shape_string = "f32[123,456]sparse{10}";
TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10);
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
<< "expected: " << ShapeUtil::HumanString(expected)
<< "actual: " << ShapeUtil::HumanString(actual);
}
TEST_F(HloParserTest, ParseShapeStringWithMemorySpaceLayout) {
// Tile, element size, and memory space.
string shape_string = "pred[123,456]{1,0:T(2,128)E(1)S(3)}";
@ -3047,10 +2939,8 @@ TEST_F(HloParserTest, ParseTokenType) {
}
TEST_F(HloParserTest, ParseInvalidShapeString) {
string shape_strings[] = {
"f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}",
"f32[123,456]dense{foo}", "f32[123,456]sparse{foo}",
};
string shape_strings[] = {"f32[123,456]foobar{0,1}", "f32[123,456]{foo}",
"f32[123,456]dense{foo}"};
for (const string& shape_string : shape_strings) {
StatusOr<Shape> result = ParseShape(shape_string);
ASSERT_FALSE(result.ok()) << "shape: " << shape_string;

View File

@ -33,17 +33,6 @@ limitations under the License.
namespace xla {
Status VerifyNotSparse(const Shape& shape) {
return ShapeUtil::ForEachSubshapeWithStatus(
shape, [](const Shape& subshape, const ShapeIndex&) -> Status {
if (LayoutUtil::IsSparseArray(subshape)) {
return InternalError("Sparse arrays are not yet fully supported: %s",
ShapeUtil::HumanStringWithLayout(subshape));
}
return Status::OK();
});
}
bool IsCallerInstruction(HloInstruction* hlo) {
switch (hlo->opcode()) {
case HloOpcode::kCall:
@ -93,8 +82,6 @@ Status ShapeVerifier::Preprocess(HloInstruction* hlo) {
"Called computations specified for non-caller instruction %s",
hlo->ToString());
}
TF_RETURN_IF_ERROR(VerifyNotSparse(hlo->shape()));
absl::optional<int> arity = HloOpcodeArity(hlo->opcode());
if (arity) {
TF_RETURN_IF_ERROR(CheckOperandCount(hlo, *arity));
@ -1109,8 +1096,6 @@ Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) {
TF_RETURN_IF_ERROR(
ShapeUtil::ValidateShapeWithOptionalLayout(result_layout.shape()));
TF_RETURN_IF_ERROR(VerifyNotSparse(result_layout.shape()));
if (!ShapeUtil::Compatible(computation->root_instruction()->shape(),
result_layout.shape())) {
return InternalError(
@ -1131,7 +1116,6 @@ Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) {
const HloInstruction* parameter = computation->parameter_instruction(i);
TF_RETURN_IF_ERROR(
ShapeUtil::ValidateShapeWithOptionalLayout(layout.parameter_shape(i)));
TF_RETURN_IF_ERROR(VerifyNotSparse(layout.parameter_shape(i)));
if (!ShapeUtil::Compatible(parameter->shape(), layout.parameter_shape(i))) {
return InternalError(
"Shape of the entry computation parameter %d is %s should be "

View File

@ -73,7 +73,7 @@ namespace xla {
// - EqualTo
// - CompatibleTo
// - IsScalar/IsEffectiveScalar/IsArray/IsTuple
// - IsDenseArray/IsSparseArray
// - IsDenseArray
// - WithLayout: layout shape's layout matches the given pattern (e.g.
// Layout().WithDenseFormat())
// - WithLayoutEqualTo: shape's layout equals the argument (i.e. another
@ -87,7 +87,7 @@ namespace xla {
//
// Layout():
// - EqualTo
// - WithDenseFormat/WithSparseFormat
// - WithDenseFormat
//
// Op(), Shape(), and Layout() may be passed an argument of type
// HloInstruction**, Shape**, or Layout**, respectively, or const versions of
@ -506,12 +506,6 @@ class LayoutPattern {
return AppendImpl(LayoutPatternFormatImpl(DENSE));
}
// Modifies the pattern to match only if the layout has a sparse format.
constexpr auto WithSparseFormat() const
-> decltype(this->AppendImpl(LayoutPatternFormatImpl(SPARSE))) {
return AppendImpl(LayoutPatternFormatImpl(SPARSE));
}
private:
Impl impl_;
LayoutType** matched_layout_;
@ -1060,11 +1054,6 @@ class ShapePattern {
return WithLayout(Layout().WithDenseFormat());
}
constexpr auto IsSparseArray() const
-> decltype(this->WithLayout(Layout().WithSparseFormat())) {
return WithLayout(Layout().WithSparseFormat());
}
// Modifies the pattern to match only if the shape has a subshape that matches
// the given pattern.
template <typename SubshapeType, typename SubshapeImpl>

View File

@ -56,9 +56,6 @@ TEST(PatternMatcherGmock, MatchShape) {
TEST(PatternMatcherGmock, MatchLayout) {
Layout l = LayoutUtil::MakeLayout({0, 1});
EXPECT_THAT(l, GmockMatch(m::Layout()));
EXPECT_THAT(&l, Not(GmockMatch(m::Layout().WithSparseFormat())));
EXPECT_THAT(Describe<Layout>(GmockMatch(m::Layout().WithSparseFormat())),
"a layout with format SPARSE");
}
TEST(PatternMatchGmock, MatchInstruction) {

View File

@ -89,7 +89,6 @@ TEST_F(PatternMatcherTest, DenseArrayShape) {
EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray()));
EXPECT_EQ(matched_shape, &array_shape);
EXPECT_TRUE(Match(&array_shape, match::Shape().IsDenseArray()));
EXPECT_FALSE(Match(&array_shape, match::Shape().IsSparseArray()));
EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar()));
EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple()));
EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32)));
@ -97,38 +96,12 @@ TEST_F(PatternMatcherTest, DenseArrayShape) {
EXPECT_FALSE(
Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape())));
Layout* matched_layout;
EXPECT_FALSE(Match(&array_shape,
match::Shape().WithLayout(
match::Layout(&matched_layout).WithSparseFormat())));
EXPECT_TRUE(Match(&array_shape,
match::Shape().WithLayout(
match::Layout(&matched_layout).WithDenseFormat())));
EXPECT_EQ(matched_layout, &array_shape.layout());
}
TEST_F(PatternMatcherTest, SparseArrayShape) {
auto array_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {2, 3, 4}, 10);
Shape* matched_shape;
EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray()));
EXPECT_EQ(matched_shape, &array_shape);
EXPECT_FALSE(Match(&array_shape, match::Shape().IsDenseArray()));
EXPECT_TRUE(Match(&array_shape, match::Shape().IsSparseArray()));
EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar()));
EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple()));
EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32)));
EXPECT_TRUE(Match(&array_shape, match::Shape().WithRank(3)));
EXPECT_FALSE(
Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape())));
Layout* matched_layout;
EXPECT_FALSE(Match(&array_shape,
match::Shape().WithLayout(
match::Layout(&matched_layout).WithDenseFormat())));
EXPECT_TRUE(Match(&array_shape,
match::Shape().WithLayout(
match::Layout(&matched_layout).WithSparseFormat())));
EXPECT_EQ(matched_layout, &array_shape.layout());
}
TEST_F(PatternMatcherTest, TupleShape) {
auto tuple_shape = ShapeUtil::MakeTupleShape({
ShapeUtil::MakeShape(F32, {1, 2, 3}),
@ -568,15 +541,6 @@ TEST_F(PatternMatcherTest, LayoutDescribeToAndExplain) {
EXPECT_DESC_AND_EXPLANATION(layout2, m::Layout().EqualTo(&layout),
"a layout equal to {1,2}",
"Layout {2,2} is not equal to expected {1,2}");
EXPECT_DESC_AND_EXPLANATION(layout2, m::Layout().WithSparseFormat(),
"a layout with format SPARSE",
"Layout has format DENSE but expected SPARSE");
EXPECT_DESC_AND_EXPLANATION(layout,
m::Layout().EqualTo(&layout).WithSparseFormat(),
"a layout:\n"
" * equal to {1,2} AND\n"
" * with format SPARSE",
"Layout has format DENSE but expected SPARSE");
}
TEST_F(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) {
@ -665,11 +629,6 @@ TEST_F(PatternMatcherTest, ShapeDescribeToAndExplain) {
"a shape with\n a layout equal to {0,1}",
"Layout {1,0} is not equal to expected {0,1}\n"
"in f32[1,2]{1,0}");
EXPECT_DESC_AND_EXPLANATION(
shape, m::Shape().WithLayout(m::Layout().WithSparseFormat()),
"a shape with\n a layout with format SPARSE",
"Layout has format DENSE but expected SPARSE\n"
"in f32[1,2]{0,1}");
EXPECT_DESC_AND_EXPLANATION(shape,
m::Shape().WithSubshapeEqualTo({10}, &shape),
"a shape with subshape at index {10} which is\n"

View File

@ -229,16 +229,6 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
return MakeShapeWithLayout(element_type, dimensions, layout);
}
/* static */ Shape ShapeUtil::MakeShapeWithSparseLayout(
PrimitiveType element_type, absl::Span<const int64> dimensions,
int64 max_sparse_elements) {
CHECK(IsArrayPrimitiveType(element_type));
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
*shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
TF_DCHECK_OK(ShapeUtil::ValidateShape(shape));
return shape;
}
/* static */ Shape
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
const Shape& shape) {
@ -637,9 +627,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return ByteSizeOfTupleIndexTable(shape, pointer_size);
} else if (shape.IsArray()) {
int64 byte_size = ByteSizeOfElements(shape);
if (LayoutUtil::IsSparseArray(shape)) {
byte_size += ByteSizeOfSparseIndices(shape);
}
return byte_size;
} else if (shape.element_type() == TOKEN) {
return 0;
@ -664,23 +651,12 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
CHECK(shape.IsArray());
int64 allocated_element_count;
if (LayoutUtil::IsSparseArray(shape)) {
allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout());
} else {
CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString();
allocated_element_count = ElementsIn(shape);
}
CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString();
allocated_element_count = ElementsIn(shape);
return allocated_element_count *
ByteSizeOfPrimitiveType(shape.element_type());
}
/* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) {
TF_DCHECK_OK(ValidateShape(shape));
CHECK(LayoutUtil::IsSparseArray(shape));
return LayoutUtil::MaxSparseElements(shape.layout()) * shape.rank() *
sizeof(int64);
}
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
const Shape& shape) {
if (shape.element_type() == PRIMITIVE_TYPE_INVALID ||
@ -721,9 +697,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return Status::OK();
}
if (LayoutUtil::IsSparseArray(shape) && shape.rank() == 0) {
return InvalidArgument("sparse arrays must have rank > 0");
}
for (int64 i = 0; i < shape.rank(); ++i) {
int64 dimension = shape.dimensions(i);
if (dimension < 0) {
@ -744,43 +717,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return Status::OK();
}
// We can only reason about some aspects of array's shape if it has a valid
// layout, these aspects will be ignored otherwise.
bool shape_has_valid_layout = LayoutUtil::HasLayout(shape) &&
LayoutUtil::ValidateLayoutInShape(shape).ok();
int64 shape_size = [&]() {
if (shape_has_valid_layout && LayoutUtil::IsSparseArray(shape)) {
int64 max_sparse_elements = LayoutUtil::MaxSparseElements(shape.layout());
if (max_sparse_elements < 0) {
return max_sparse_elements;
}
int64 sparse_elements_size = MultiplyWithoutOverflow(
max_sparse_elements, ByteSizeOfPrimitiveType(shape.element_type()));
if (sparse_elements_size < 0) {
return sparse_elements_size;
}
int64 sparse_indices_size =
MultiplyWithoutOverflow(max_sparse_elements, shape.rank());
if (sparse_indices_size < 0) {
return sparse_indices_size;
}
sparse_indices_size =
MultiplyWithoutOverflow(sparse_indices_size, sizeof(int64));
if (sparse_indices_size < 0) {
return sparse_indices_size;
}
// At this point, both sparse_indices_size and sparse_elements_size are
// non-negative, so we can easily check if adding them wraps.
if (static_cast<uint64>(sparse_elements_size) +
static_cast<uint64>(sparse_indices_size) >
INT64_MAX) {
return static_cast<int64>(-1);
}
}
// This is intentionally unconditional: even if the shape is sparse, we want
// to verify the densified version has a reasonable size.
int64 dense_shape_size = 1;
if (shape.dimensions().empty()) {
return dense_shape_size;

View File

@ -192,10 +192,7 @@ class ShapeUtil {
};
// Returns the number of elements are contained within the provided shape;
// e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes
// may not actually be able to store this number of elements. See
// LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of
// elements that can be stored in a sparse shape.
// e.g. for rank 0 (scalars) the result is always 1.
// Precondition: shape.IsArray()
static int64 ElementsIn(const Shape& shape);
@ -228,20 +225,12 @@ class ShapeUtil {
int64 pointer_size);
// Returns the number of bytes required for the elements in an allocation of
// `shape`, which must be an array shape. The return value does not include
// the bytes needed to store sparse indices. Dense shapes use a separate
// `shape`, which must be an array shape. Shapes use a separate
// memory location for each element, and so for these shapes,
// `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. For dense shapes, this
// size also includes padding if present in the layout. For sparse shapes,
// `ByteSizeOf(shape) == ByteSizeOfElements(shape) +
// ByteSizeOfSparseindices(shape)`.
// `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. This
// size also includes padding if present in the layout.
static int64 ByteSizeOfElements(const Shape& shape);
// Returns the number of bytes required for the sparse indices in an
// allocation of shape. The shape must be an array shape. The return value
// does not include the bytes needed to store sparse indices.
static int64 ByteSizeOfSparseIndices(const Shape& shape);
// Returns a human-readable string that represents the given shape, with or
// without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]".
static string HumanString(const Shape& shape);
@ -427,9 +416,6 @@ class ShapeUtil {
int64 element_size_in_bits = 0,
int64 memory_space = 0);
static Shape MakeShapeWithSparseLayout(PrimitiveType element_type,
absl::Span<const int64> dimensions,
int64 max_sparse_elements);
// Returns the same shape except with all dimensions set to be static.
static Shape MakeShapeWithStaticDimensions(const Shape& shape);

View File

@ -1,109 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/sparse_index_array.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
namespace xla {
SparseIndexArray::SparseIndexArray() : rank_(0), max_indices_(0) {}
SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank,
std::vector<int64> indices)
: indices_(std::move(indices)), rank_(rank), max_indices_(max_indices) {
CHECK_GT(rank_, 0);
CHECK_EQ(indices_.size() % rank_, 0)
<< "indices_.size(): " << indices_.size() << ", rank_: " << rank_;
CHECK_LE(index_count(), max_indices_);
}
SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank,
absl::Span<const int64> indices)
: SparseIndexArray(max_indices, rank,
std::vector<int64>(indices.begin(), indices.end())) {}
SparseIndexArray::SparseIndexArray(int64 max_indices,
const Array2D<int64>& indices)
: SparseIndexArray(max_indices, indices.n2(),
std::vector<int64>(indices.begin(), indices.end())) {}
int64 SparseIndexArray::index_count() const {
CHECK_GT(rank_, 0);
CHECK_EQ(indices_.size() % rank_, 0);
return indices_.size() / rank_;
}
absl::Span<const int64> SparseIndexArray::At(
int64 sparse_element_number) const {
CHECK_GT(rank_, 0);
CHECK_GE(sparse_element_number, 0);
CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size());
return absl::Span<const int64>(
indices_.data() + rank_ * sparse_element_number, rank_);
}
absl::Span<int64> SparseIndexArray::At(int64 sparse_element_number) {
CHECK_GT(rank_, 0);
CHECK_GE(sparse_element_number, 0);
CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size());
return absl::Span<int64>(indices_.data() + rank_ * sparse_element_number,
rank_);
}
void SparseIndexArray::Append(absl::Span<const int64> index) {
CHECK_GT(rank_, 0);
CHECK_EQ(index.size(), rank_);
indices_.insert(indices_.end(), index.begin(), index.end());
}
void SparseIndexArray::Clear() { indices_.clear(); }
void SparseIndexArray::Resize(int64 num_indices) {
CHECK_GT(rank_, 0);
indices_.resize(rank_ * num_indices);
}
bool SparseIndexArray::Validate(const Shape& shape) const {
if (rank_ == 0 || rank_ != shape.rank()) {
return false;
}
int64 num_indices = index_count();
if (num_indices > LayoutUtil::MaxSparseElements(shape.layout())) {
return false;
}
if (num_indices < 2) {
return true;
}
absl::Span<const int64> last = At(0);
if (!IndexUtil::IndexInBounds(shape, last)) {
return false;
}
for (int64 n = 1; n < num_indices; ++n) {
absl::Span<const int64> next = At(n);
if (!IndexUtil::IndexInBounds(shape, next)) {
return false;
}
if (IndexUtil::CompareIndices(last, next) >= 0) {
return false;
}
last = next;
}
return true;
}
} // namespace xla

View File

@ -1,176 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Utility class for managing sparse array indices.
#ifndef TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
#define TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
#include <vector>
#include "absl/container/inlined_vector.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
// Encapsulates the array of indices for a sparse array. A SparseIndexArray
// contain indices for up to `max_indices` elements of a sparse array. Each
// sparse index is an array of `rank` int64 value that gives the location of a
// value within a sparse array. Note that the dimensions of the array are not
// checked (except for the rank). To avoid confusion, we refer to the position
// of an index within a SparseIndexArray as a sparse index number.
class SparseIndexArray {
public:
SparseIndexArray();
SparseIndexArray(const SparseIndexArray&) = default;
SparseIndexArray(SparseIndexArray&&) = default;
SparseIndexArray& operator=(const SparseIndexArray&) = default;
SparseIndexArray& operator=(SparseIndexArray&&) = default;
// Constructs a SparseIndexArray that can hold up to `max_indices` sparse
// indices, with an initial contents obtained from the given array. The rank
// is taken from the minor dimension of the array. The major dimension of the
// array must not exceed `max_indices`.
SparseIndexArray(int64 max_indices, const Array2D<int64>& indices);
// Like above, but the array is flattened. For example, the following are
// equivalent:
//
// SparseIndexArray(10, 3,
// Array2D{
// {0, 1, 2},
// {3, 4, 5},
// {6, 7, 8},
// {9, 10, 11},
// })
//
// SparseIndexArray(10, 3,
// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11})
//
SparseIndexArray(int64 max_indices, int64 rank,
std::vector<int64> indices = {});
SparseIndexArray(int64 max_indices, int64 rank,
absl::Span<const int64> indices);
// Returns the number of elements represented by the indices stored in the
// array.
int64 index_count() const;
// Returns a slice that refers to the given sparse index number. The argument
// must be in the range [0, element_count()).
absl::Span<const int64> At(int64 sparse_element_number) const;
absl::Span<int64> At(int64 sparse_element_number);
// Adds the given index at the end of the array. The new size of the
// SparseIndexArray must not exceed `max_indices`.
void Append(absl::Span<const int64> index);
// Removes all indices from the array.
void Clear();
// Resizes the array to contain the given number of sparse indices. The new
// size must be smaller than `max_indices`. If the new size is larger than
// the old size, the value of the new indices is not specified.
void Resize(int64 num_indices);
// Returns true iff all indices are unique and occur in sorted order, and are
// valid for the given shape.
bool Validate(const Shape& shape) const;
int64 rank() const { return rank_; }
int64 max_indices() const { return max_indices_; }
// Returns a pointer to the int64 array that holds the sparse indices.
absl::Span<int64> mutable_data() { return absl::MakeSpan(indices_); }
absl::Span<const int64> data() const { return indices_; }
// Sorts this sparse index array along with the set of corresponding values.
// The indices and values are sorted in the lexicographic order of the
// indices, from smallest to largest.
//
// For example:
//
// std::vector<float> v{10.0, 11.0, 12.0};
// SparseIndexArray a(10, 3,
// {{3, 4, 5},
// {1, 2, 3},
// {2, 3, 4}});
// a.SortWithValues(&v);
// // Prints "11.0, 12.0, 10.0":
// std::cout << v[0] << ", " << v[1] << ", " << v[2] << std::endl;
//
template <typename NativeT>
void SortWithValues(absl::Span<NativeT> values);
private:
std::vector<int64> indices_;
int64 rank_;
int64 max_indices_;
};
template <typename NativeT>
void SparseIndexArray::SortWithValues(absl::Span<NativeT> values) {
int64 num_elements = index_count();
CHECK_EQ(values.size(), num_elements);
std::vector<int64> sort_order;
sort_order.reserve(num_elements);
for (int64 i = 0; i < num_elements; ++i) {
sort_order.push_back(i);
}
auto sort_order_less = [this](int64 lhs, int64 rhs) {
return IndexUtil::CompareIndices(At(lhs), At(rhs)) < 0;
};
absl::c_sort(sort_order, sort_order_less);
// Reorder the array elements according to sort_order. Work through the array
// and follow cycles so we can do the reorder in-place.
absl::InlinedVector<int64, 8> saved_index(rank());
for (int64 i = 0; i < num_elements; ++i) {
// sort_order[i] == -1 indicates the element has already been copied.
if (sort_order[i] < 0) {
continue;
} else if (i == sort_order[i]) {
// The element is already in sorted order.
sort_order[i] = -1;
continue;
}
std::copy_n(At(i).begin(), rank(), saved_index.begin());
NativeT saved_value = values[i];
int64 j = i;
for (;;) {
if (sort_order[j] == i) {
std::copy_n(saved_index.begin(), rank(), At(j).begin());
values[j] = saved_value;
sort_order[j] = -1;
break;
}
std::copy_n(At(sort_order[j]).begin(), rank(), At(j).begin());
values[j] = values[sort_order[j]];
int64 k = sort_order[j];
sort_order[j] = -1;
j = k;
}
}
}
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_

View File

@ -1,43 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/sparse_index_array.h"
#include <vector>
#include "tensorflow/compiler/xla/test.h"
namespace xla {
namespace {
TEST(SparseIndexArrayTest, Sort) {
SparseIndexArray a(10, 3);
a.Append({2, 3, 4});
a.Append({3, 4, 5});
a.Append({1, 2, 3});
a.Append({5, 6, 7});
a.Append({4, 5, 6});
a.Append({6, 7, 8});
std::vector<double> values = {
12.0, 13.0, 11.0, 15.0, 14.0, 16.0,
};
a.SortWithValues<double>(absl::MakeSpan(values));
ASSERT_EQ(a.data(), std::vector<int64>({1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5,
6, 7, 6, 7, 8}));
ASSERT_EQ(values, std::vector<double>({11.0, 12.0, 13.0, 14.0, 15.0, 16.0}));
}
} // namespace
} // namespace xla

View File

@ -115,9 +115,8 @@ enum Format {
INVALID_FORMAT = 0;
// The default layout, with exactly one storage location per element.
DENSE = 1;
// A sparsely encoded layout, providing only the index/value pairs of non-zero
// elements.
SPARSE = 2;
reserved 2;
reserved "SPARSE";
}
// Describes a tile used in tiling-based layout. Refer to
@ -156,10 +155,8 @@ message LayoutProto {
reserved 3;
reserved "padding_value";
// The maximum number of elements that can be stored for SPARSE formats. This
// can be used to determine the maximum size in bytes of arrays stored in
// memory. This field must be unset unless the format is SPARSE.
int64 max_sparse_elements = 5;
reserved 5;
reserved "max_sparse_elements";
// A sequence of tiles, starting from the tile that's applied first to the
// Shape.