[XLA] Add a few helper functions around dynamic dimensions.
PiperOrigin-RevId: 282408969 Change-Id: I9058ee17ae239d24ed741246af4288905be10212
This commit is contained in:
parent
b1e67d9456
commit
827b2f4723
@ -77,6 +77,10 @@ class Shape {
|
|||||||
return dynamic_dimensions_;
|
return dynamic_dimensions_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
absl::Span<bool> mutable_dynamic_dimensions() {
|
||||||
|
return absl::MakeSpan(dynamic_dimensions_);
|
||||||
|
}
|
||||||
|
|
||||||
// Add dimension_upper_bound().
|
// Add dimension_upper_bound().
|
||||||
|
|
||||||
// Removes the given dimension form the shape. Layout, if it exists, is
|
// Removes the given dimension form the shape. Layout, if it exists, is
|
||||||
@ -127,6 +131,19 @@ class Shape {
|
|||||||
Layout* mutable_layout() { return &layout_; }
|
Layout* mutable_layout() { return &layout_; }
|
||||||
void clear_layout() { layout_.Clear(); }
|
void clear_layout() { layout_.Clear(); }
|
||||||
|
|
||||||
|
// Recursively clear dynamic dimension of a shape.
|
||||||
|
void clear_dynamic_dimensions() {
|
||||||
|
if (!IsTuple()) {
|
||||||
|
for (int64 i = 0; i < dynamic_dimensions_.size(); ++i) {
|
||||||
|
dynamic_dimensions_[i] = false;
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (auto& subshape : tuple_shapes_) {
|
||||||
|
subshape.clear_dynamic_dimensions();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Swap(Shape* other) {
|
void Swap(Shape* other) {
|
||||||
using std::swap;
|
using std::swap;
|
||||||
swap(*this, *other);
|
swap(*this, *other);
|
||||||
|
@ -176,6 +176,10 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
|||||||
return MakeValidatedShape(element_type, dimensions).ValueOrDie();
|
return MakeValidatedShape(element_type, dimensions).ValueOrDie();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) {
|
||||||
|
return MakeShape(element_type, {});
|
||||||
|
}
|
||||||
|
|
||||||
/* static */ Shape ShapeUtil::MakeShape(
|
/* static */ Shape ShapeUtil::MakeShape(
|
||||||
PrimitiveType element_type, absl::Span<const int64> dimensions,
|
PrimitiveType element_type, absl::Span<const int64> dimensions,
|
||||||
const std::vector<bool>& dynamic_dimensions) {
|
const std::vector<bool>& dynamic_dimensions) {
|
||||||
@ -183,6 +187,13 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
|||||||
.ValueOrDie();
|
.ValueOrDie();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */ Shape ShapeUtil::MakeShapeWithStaticDimensions(
|
||||||
|
const Shape& shape) {
|
||||||
|
Shape output = shape;
|
||||||
|
output.clear_dynamic_dimensions();
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
/* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
|
/* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
|
||||||
PrimitiveType element_type, absl::Span<const int64> dimensions) {
|
PrimitiveType element_type, absl::Span<const int64> dimensions) {
|
||||||
CHECK(IsArrayPrimitiveType(element_type)) << element_type;
|
CHECK(IsArrayPrimitiveType(element_type)) << element_type;
|
||||||
@ -242,6 +253,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
|||||||
shape.layout().tiles().begin(), shape.layout().tiles().end());
|
shape.layout().tiles().begin(), shape.layout().tiles().end());
|
||||||
new_shape.mutable_layout()->set_element_size_in_bits(
|
new_shape.mutable_layout()->set_element_size_in_bits(
|
||||||
shape.layout().element_size_in_bits());
|
shape.layout().element_size_in_bits());
|
||||||
|
for (int i = 0; i < shape.dimensions_size(); ++i) {
|
||||||
|
new_shape.set_dynamic_dimension(i, shape.is_dynamic_dimension(i));
|
||||||
|
}
|
||||||
return new_shape;
|
return new_shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -294,6 +308,20 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
|||||||
*tuple_shape->mutable_tuple_shapes(index) = shape;
|
*tuple_shape->mutable_tuple_shapes(index) = shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */ void ShapeUtil::UpdateDynamicDimension(Shape* shape,
|
||||||
|
ShapeIndexView index,
|
||||||
|
int64 dim,
|
||||||
|
bool is_dynamic) {
|
||||||
|
if (index.empty()) {
|
||||||
|
CHECK(!shape->IsTuple());
|
||||||
|
shape->set_dynamic_dimension(dim, is_dynamic);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
UpdateDynamicDimension(shape->mutable_tuple_shapes(index.front()),
|
||||||
|
index.ConsumeFront(), dim, is_dynamic);
|
||||||
|
}
|
||||||
|
|
||||||
/* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
|
/* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
|
||||||
CHECK(LayoutUtil::IsDenseArray(*shape));
|
CHECK(LayoutUtil::IsDenseArray(*shape));
|
||||||
shape->mutable_layout()->add_minor_to_major(shape->rank());
|
shape->mutable_layout()->add_minor_to_major(shape->rank());
|
||||||
|
@ -368,6 +368,10 @@ class ShapeUtil {
|
|||||||
static void UpdateTupleShape(const Shape& shape, int64 index,
|
static void UpdateTupleShape(const Shape& shape, int64 index,
|
||||||
Shape* tuple_shape);
|
Shape* tuple_shape);
|
||||||
|
|
||||||
|
// Update the dynamic dimension for a shape. This shape can be a nested tuple.
|
||||||
|
static void UpdateDynamicDimension(Shape* shape, ShapeIndexView index,
|
||||||
|
int64 dim, bool is_dynamic);
|
||||||
|
|
||||||
// Appends a major dimension to the shape with the given bound.
|
// Appends a major dimension to the shape with the given bound.
|
||||||
static void AppendMajorDimension(int bound, Shape* shape);
|
static void AppendMajorDimension(int bound, Shape* shape);
|
||||||
|
|
||||||
@ -394,6 +398,9 @@ class ShapeUtil {
|
|||||||
absl::Span<const int64> dimensions,
|
absl::Span<const int64> dimensions,
|
||||||
const std::vector<bool>& dynamic_dimensions);
|
const std::vector<bool>& dynamic_dimensions);
|
||||||
|
|
||||||
|
// Make a scalar shape with given primitive type.
|
||||||
|
static Shape MakeScalarShape(PrimitiveType element_type);
|
||||||
|
|
||||||
// Constructs a new shape with the given element type and sequence of
|
// Constructs a new shape with the given element type and sequence of
|
||||||
// dimensions. Method checks if the element type is valid and the shape's
|
// dimensions. Method checks if the element type is valid and the shape's
|
||||||
// size fits in std::numeric_limits<int64>::max().
|
// size fits in std::numeric_limits<int64>::max().
|
||||||
@ -424,6 +431,9 @@ class ShapeUtil {
|
|||||||
absl::Span<const int64> dimensions,
|
absl::Span<const int64> dimensions,
|
||||||
int64 max_sparse_elements);
|
int64 max_sparse_elements);
|
||||||
|
|
||||||
|
// Returns the same shape except with all dimensions set to be static.
|
||||||
|
static Shape MakeShapeWithStaticDimensions(const Shape& shape);
|
||||||
|
|
||||||
// Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
|
// Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
|
||||||
static Shape MakeShapeWithDescendingLayout(
|
static Shape MakeShapeWithDescendingLayout(
|
||||||
PrimitiveType element_type, absl::Span<const int64> dimensions);
|
PrimitiveType element_type, absl::Span<const int64> dimensions);
|
||||||
|
@ -732,6 +732,15 @@ TEST(ShapeUtilTest, PermuteDimensionsLayout) {
|
|||||||
} while (std::next_permutation(layout.begin(), layout.end()));
|
} while (std::next_permutation(layout.begin(), layout.end()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ShapeUtilTest, UpdateDynamicDimensions) {
|
||||||
|
Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
|
||||||
|
|
||||||
|
Shape tuple_shape = ShapeUtil::MakeTupleShape({shape});
|
||||||
|
|
||||||
|
ShapeUtil::UpdateDynamicDimension(&tuple_shape, {0}, 1, true);
|
||||||
|
EXPECT_TRUE(ShapeUtil::GetSubshape(tuple_shape, {0}).is_dynamic_dimension(1));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(ShapeUtilTest, PermuteDynamicDimensions) {
|
TEST(ShapeUtilTest, PermuteDynamicDimensions) {
|
||||||
Shape shape =
|
Shape shape =
|
||||||
ShapeUtil::MakeShape(F32, {10, 100, 1000},
|
ShapeUtil::MakeShape(F32, {10, 100, 1000},
|
||||||
|
Loading…
Reference in New Issue
Block a user