[XLA] Delete StripDegenerateDimensions()
This is unused, and, as it turns out, is broken for sparse shapes. PiperOrigin-RevId: 200313641
This commit is contained in:
parent
213810a0d6
commit
bbc2c612da
@ -939,68 +939,6 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
|
|||||||
return leaves;
|
return leaves;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) {
|
|
||||||
CHECK(IsArray(shape));
|
|
||||||
|
|
||||||
std::vector<int64> dimension_sizes;
|
|
||||||
std::vector<int64> degenerate_dimensions;
|
|
||||||
for (int64 i = 0; i < shape.dimensions_size(); ++i) {
|
|
||||||
if (shape.dimensions(i) == 1) {
|
|
||||||
degenerate_dimensions.push_back(i);
|
|
||||||
} else {
|
|
||||||
dimension_sizes.push_back(shape.dimensions(i));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Construct minor_to_major of stripped shape. The order of the non-degenerate
|
|
||||||
// dimensions should be preserved from the original shape. First, create
|
|
||||||
// vector of the non-degenerate dimensions from the original minor_to_major
|
|
||||||
// array.
|
|
||||||
std::vector<int64> minor_to_major;
|
|
||||||
for (int64 i : shape.layout().minor_to_major()) {
|
|
||||||
if (std::find(degenerate_dimensions.begin(), degenerate_dimensions.end(),
|
|
||||||
i) == degenerate_dimensions.end()) {
|
|
||||||
minor_to_major.push_back(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// The dimensions in minor_to_major need to be renumbered to account for the
|
|
||||||
// degenerate dimensions which have removed. Decrement each dimension number
|
|
||||||
// once for each degenerate dimension which has a smaller number.
|
|
||||||
for (int i = 0; i < minor_to_major.size(); ++i) {
|
|
||||||
int adjustment = 0;
|
|
||||||
for (int64 dim : degenerate_dimensions) {
|
|
||||||
if (minor_to_major[i] > dim) {
|
|
||||||
adjustment++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
minor_to_major[i] -= adjustment;
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
std::vector<int64> dims(minor_to_major.size());
|
|
||||||
std::iota(dims.begin(), dims.end(), 0);
|
|
||||||
DCHECK(minor_to_major.size() == dims.size() &&
|
|
||||||
std::is_permutation(minor_to_major.begin(), minor_to_major.end(),
|
|
||||||
dims.begin()));
|
|
||||||
}
|
|
||||||
Shape stripped_shape;
|
|
||||||
if (LayoutUtil::IsDenseArray(shape)) {
|
|
||||||
stripped_shape = MakeShapeWithLayout(shape.element_type(), dimension_sizes,
|
|
||||||
minor_to_major);
|
|
||||||
} else if (LayoutUtil::IsSparseArray(shape)) {
|
|
||||||
stripped_shape =
|
|
||||||
MakeShapeWithSparseLayout(shape.element_type(), dimension_sizes,
|
|
||||||
shape.layout().max_sparse_elements());
|
|
||||||
} else {
|
|
||||||
stripped_shape = MakeShape(shape.element_type(), dimension_sizes);
|
|
||||||
}
|
|
||||||
|
|
||||||
VLOG(10) << "Original_shape: " << HumanStringWithLayout(shape);
|
|
||||||
VLOG(10) << "Stripped_shape: " << HumanStringWithLayout(stripped_shape);
|
|
||||||
return stripped_shape;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Helper for ForEachSubshape which visits the subshapes of the given shape in
|
// Helper for ForEachSubshape which visits the subshapes of the given shape in
|
||||||
|
@ -510,26 +510,6 @@ class ShapeUtil {
|
|||||||
static Status ForEachMutableSubshapeWithStatus(
|
static Status ForEachMutableSubshapeWithStatus(
|
||||||
Shape* shape, const MutatingStatusVisitorFunction& func);
|
Shape* shape, const MutatingStatusVisitorFunction& func);
|
||||||
|
|
||||||
// Removes all degenerate dimensions (size one) from the given shape. The
|
|
||||||
// stripped minor_to_major preserves the relative ordering of non-degenerate
|
|
||||||
// dimensions. The stripped shape has the property that the underlying
|
|
||||||
// representation (bits in memory) for the stripped shape is the same as the
|
|
||||||
// original shape modulo padding. Examples:
|
|
||||||
//
|
|
||||||
// input shape: F32 [1, 2, 1], minor_to_major = {0, 1, 2}
|
|
||||||
// stripped shape: F32 [2], minor_to_major = {0}
|
|
||||||
//
|
|
||||||
// input shape: F32 [6, 1, 5], minor_to_major = {2, 0, 1}
|
|
||||||
// stripped shape: F32 [6, 5], minor_to_major = {1, 0}
|
|
||||||
//
|
|
||||||
// input shape: F32 [1, 7, 1, 6, 5, 1], minor_to_major = {0, 2, 5, 4, 3, 1}
|
|
||||||
// stripped shape: F32 [7, 6, 5], minor_to_major = {0, 2, 1}
|
|
||||||
//
|
|
||||||
// input shape: F32 [1, 1], minor_to_major = {0, 1}
|
|
||||||
// stripped shape: F32 [], minor_to_major = {}
|
|
||||||
// Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)
|
|
||||||
static Shape StripDegenerateDimensions(const Shape& shape);
|
|
||||||
|
|
||||||
// Permutes the dimensions by the given permutation, so
|
// Permutes the dimensions by the given permutation, so
|
||||||
// return_value.dimensions[permutation[i]] = argument.dimensions[i]
|
// return_value.dimensions[permutation[i]] = argument.dimensions[i]
|
||||||
static Shape PermuteDimensions(tensorflow::gtl::ArraySlice<int64> permutation,
|
static Shape PermuteDimensions(tensorflow::gtl::ArraySlice<int64> permutation,
|
||||||
|
@ -742,16 +742,6 @@ TEST(ShapeUtilTest, ReshapeIsBitcast_3x2x2_6x2_Dim1IsMostMinor) {
|
|||||||
ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1})));
|
ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1})));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ShapeUtilTest, StripDegenerateDimensions) {
|
|
||||||
EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::StripDegenerateDimensions(
|
|
||||||
ShapeUtil::MakeShape(F32, {3, 1, 2})),
|
|
||||||
ShapeUtil::MakeShape(F32, {3, 2})));
|
|
||||||
EXPECT_TRUE(ShapeUtil::Equal(
|
|
||||||
ShapeUtil::StripDegenerateDimensions(
|
|
||||||
ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 1, 2}, 10)),
|
|
||||||
ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 2}, 10)));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) {
|
TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) {
|
||||||
EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast(
|
EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast(
|
||||||
ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}),
|
ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}),
|
||||||
|
Loading…
Reference in New Issue
Block a user