[XLA] Delete StripDegenerateDimensions()
This is unused, and, as it turns out, is broken for sparse shapes. PiperOrigin-RevId: 200313641
This commit is contained in:
parent
213810a0d6
commit
bbc2c612da
@ -939,68 +939,6 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
|
||||
return leaves;
|
||||
}
|
||||
|
||||
/* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) {
|
||||
CHECK(IsArray(shape));
|
||||
|
||||
std::vector<int64> dimension_sizes;
|
||||
std::vector<int64> degenerate_dimensions;
|
||||
for (int64 i = 0; i < shape.dimensions_size(); ++i) {
|
||||
if (shape.dimensions(i) == 1) {
|
||||
degenerate_dimensions.push_back(i);
|
||||
} else {
|
||||
dimension_sizes.push_back(shape.dimensions(i));
|
||||
}
|
||||
}
|
||||
|
||||
// Construct minor_to_major of stripped shape. The order of the non-degenerate
|
||||
// dimensions should be preserved from the original shape. First, create
|
||||
// vector of the non-degenerate dimensions from the original minor_to_major
|
||||
// array.
|
||||
std::vector<int64> minor_to_major;
|
||||
for (int64 i : shape.layout().minor_to_major()) {
|
||||
if (std::find(degenerate_dimensions.begin(), degenerate_dimensions.end(),
|
||||
i) == degenerate_dimensions.end()) {
|
||||
minor_to_major.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
// The dimensions in minor_to_major need to be renumbered to account for the
|
||||
// degenerate dimensions which have removed. Decrement each dimension number
|
||||
// once for each degenerate dimension which has a smaller number.
|
||||
for (int i = 0; i < minor_to_major.size(); ++i) {
|
||||
int adjustment = 0;
|
||||
for (int64 dim : degenerate_dimensions) {
|
||||
if (minor_to_major[i] > dim) {
|
||||
adjustment++;
|
||||
}
|
||||
}
|
||||
minor_to_major[i] -= adjustment;
|
||||
}
|
||||
|
||||
{
|
||||
std::vector<int64> dims(minor_to_major.size());
|
||||
std::iota(dims.begin(), dims.end(), 0);
|
||||
DCHECK(minor_to_major.size() == dims.size() &&
|
||||
std::is_permutation(minor_to_major.begin(), minor_to_major.end(),
|
||||
dims.begin()));
|
||||
}
|
||||
Shape stripped_shape;
|
||||
if (LayoutUtil::IsDenseArray(shape)) {
|
||||
stripped_shape = MakeShapeWithLayout(shape.element_type(), dimension_sizes,
|
||||
minor_to_major);
|
||||
} else if (LayoutUtil::IsSparseArray(shape)) {
|
||||
stripped_shape =
|
||||
MakeShapeWithSparseLayout(shape.element_type(), dimension_sizes,
|
||||
shape.layout().max_sparse_elements());
|
||||
} else {
|
||||
stripped_shape = MakeShape(shape.element_type(), dimension_sizes);
|
||||
}
|
||||
|
||||
VLOG(10) << "Original_shape: " << HumanStringWithLayout(shape);
|
||||
VLOG(10) << "Stripped_shape: " << HumanStringWithLayout(stripped_shape);
|
||||
return stripped_shape;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Helper for ForEachSubshape which visits the subshapes of the given shape in
|
||||
|
@ -510,26 +510,6 @@ class ShapeUtil {
|
||||
static Status ForEachMutableSubshapeWithStatus(
|
||||
Shape* shape, const MutatingStatusVisitorFunction& func);
|
||||
|
||||
// Removes all degenerate dimensions (size one) from the given shape. The
|
||||
// stripped minor_to_major preserves the relative ordering of non-degenerate
|
||||
// dimensions. The stripped shape has the property that the underlying
|
||||
// representation (bits in memory) for the stripped shape is the same as the
|
||||
// original shape modulo padding. Examples:
|
||||
//
|
||||
// input shape: F32 [1, 2, 1], minor_to_major = {0, 1, 2}
|
||||
// stripped shape: F32 [2], minor_to_major = {0}
|
||||
//
|
||||
// input shape: F32 [6, 1, 5], minor_to_major = {2, 0, 1}
|
||||
// stripped shape: F32 [6, 5], minor_to_major = {1, 0}
|
||||
//
|
||||
// input shape: F32 [1, 7, 1, 6, 5, 1], minor_to_major = {0, 2, 5, 4, 3, 1}
|
||||
// stripped shape: F32 [7, 6, 5], minor_to_major = {0, 2, 1}
|
||||
//
|
||||
// input shape: F32 [1, 1], minor_to_major = {0, 1}
|
||||
// stripped shape: F32 [], minor_to_major = {}
|
||||
// Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)
|
||||
static Shape StripDegenerateDimensions(const Shape& shape);
|
||||
|
||||
// Permutes the dimensions by the given permutation, so
|
||||
// return_value.dimensions[permutation[i]] = argument.dimensions[i]
|
||||
static Shape PermuteDimensions(tensorflow::gtl::ArraySlice<int64> permutation,
|
||||
|
@ -742,16 +742,6 @@ TEST(ShapeUtilTest, ReshapeIsBitcast_3x2x2_6x2_Dim1IsMostMinor) {
|
||||
ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1})));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, StripDegenerateDimensions) {
|
||||
EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::StripDegenerateDimensions(
|
||||
ShapeUtil::MakeShape(F32, {3, 1, 2})),
|
||||
ShapeUtil::MakeShape(F32, {3, 2})));
|
||||
EXPECT_TRUE(ShapeUtil::Equal(
|
||||
ShapeUtil::StripDegenerateDimensions(
|
||||
ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 1, 2}, 10)),
|
||||
ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 2}, 10)));
|
||||
}
|
||||
|
||||
TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) {
|
||||
EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast(
|
||||
ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}),
|
||||
|
Loading…
Reference in New Issue
Block a user