[XLA] Add a few helper functions around dynamic dimensions.
PiperOrigin-RevId: 282408969 Change-Id: I9058ee17ae239d24ed741246af4288905be10212
This commit is contained in:
parent
b1e67d9456
commit
827b2f4723
@ -77,6 +77,10 @@ class Shape {
|
||||
return dynamic_dimensions_;
|
||||
}
|
||||
|
||||
absl::Span<bool> mutable_dynamic_dimensions() {
|
||||
return absl::MakeSpan(dynamic_dimensions_);
|
||||
}
|
||||
|
||||
// Add dimension_upper_bound().
|
||||
|
||||
// Removes the given dimension form the shape. Layout, if it exists, is
|
||||
@ -127,6 +131,19 @@ class Shape {
|
||||
Layout* mutable_layout() { return &layout_; }
|
||||
void clear_layout() { layout_.Clear(); }
|
||||
|
||||
// Recursively clear dynamic dimension of a shape.
|
||||
void clear_dynamic_dimensions() {
|
||||
if (!IsTuple()) {
|
||||
for (int64 i = 0; i < dynamic_dimensions_.size(); ++i) {
|
||||
dynamic_dimensions_[i] = false;
|
||||
}
|
||||
return;
|
||||
}
|
||||
for (auto& subshape : tuple_shapes_) {
|
||||
subshape.clear_dynamic_dimensions();
|
||||
}
|
||||
}
|
||||
|
||||
void Swap(Shape* other) {
|
||||
using std::swap;
|
||||
swap(*this, *other);
|
||||
|
@ -176,6 +176,10 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
||||
return MakeValidatedShape(element_type, dimensions).ValueOrDie();
|
||||
}
|
||||
|
||||
/* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) {
|
||||
return MakeShape(element_type, {});
|
||||
}
|
||||
|
||||
/* static */ Shape ShapeUtil::MakeShape(
|
||||
PrimitiveType element_type, absl::Span<const int64> dimensions,
|
||||
const std::vector<bool>& dynamic_dimensions) {
|
||||
@ -183,6 +187,13 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
||||
.ValueOrDie();
|
||||
}
|
||||
|
||||
/* static */ Shape ShapeUtil::MakeShapeWithStaticDimensions(
|
||||
const Shape& shape) {
|
||||
Shape output = shape;
|
||||
output.clear_dynamic_dimensions();
|
||||
return output;
|
||||
}
|
||||
|
||||
/* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
|
||||
PrimitiveType element_type, absl::Span<const int64> dimensions) {
|
||||
CHECK(IsArrayPrimitiveType(element_type)) << element_type;
|
||||
@ -242,6 +253,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
shape.layout().tiles().begin(), shape.layout().tiles().end());
|
||||
new_shape.mutable_layout()->set_element_size_in_bits(
|
||||
shape.layout().element_size_in_bits());
|
||||
for (int i = 0; i < shape.dimensions_size(); ++i) {
|
||||
new_shape.set_dynamic_dimension(i, shape.is_dynamic_dimension(i));
|
||||
}
|
||||
return new_shape;
|
||||
}
|
||||
|
||||
@ -294,6 +308,20 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
*tuple_shape->mutable_tuple_shapes(index) = shape;
|
||||
}
|
||||
|
||||
/* static */ void ShapeUtil::UpdateDynamicDimension(Shape* shape,
|
||||
ShapeIndexView index,
|
||||
int64 dim,
|
||||
bool is_dynamic) {
|
||||
if (index.empty()) {
|
||||
CHECK(!shape->IsTuple());
|
||||
shape->set_dynamic_dimension(dim, is_dynamic);
|
||||
return;
|
||||
}
|
||||
|
||||
UpdateDynamicDimension(shape->mutable_tuple_shapes(index.front()),
|
||||
index.ConsumeFront(), dim, is_dynamic);
|
||||
}
|
||||
|
||||
/* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
|
||||
CHECK(LayoutUtil::IsDenseArray(*shape));
|
||||
shape->mutable_layout()->add_minor_to_major(shape->rank());
|
||||
|
@ -368,6 +368,10 @@ class ShapeUtil {
|
||||
static void UpdateTupleShape(const Shape& shape, int64 index,
|
||||
Shape* tuple_shape);
|
||||
|
||||
// Update the dynamic dimension for a shape. This shape can be a nested tuple.
|
||||
static void UpdateDynamicDimension(Shape* shape, ShapeIndexView index,
|
||||
int64 dim, bool is_dynamic);
|
||||
|
||||
// Appends a major dimension to the shape with the given bound.
|
||||
static void AppendMajorDimension(int bound, Shape* shape);
|
||||
|
||||
@ -394,6 +398,9 @@ class ShapeUtil {
|
||||
absl::Span<const int64> dimensions,
|
||||
const std::vector<bool>& dynamic_dimensions);
|
||||
|
||||
// Make a scalar shape with given primitive type.
|
||||
static Shape MakeScalarShape(PrimitiveType element_type);
|
||||
|
||||
// Constructs a new shape with the given element type and sequence of
|
||||
// dimensions. Method checks if the element type is valid and the shape's
|
||||
// size fits in std::numeric_limits<int64>::max().
|
||||
@ -424,6 +431,9 @@ class ShapeUtil {
|
||||
absl::Span<const int64> dimensions,
|
||||
int64 max_sparse_elements);
|
||||
|
||||
// Returns the same shape except with all dimensions set to be static.
|
||||
static Shape MakeShapeWithStaticDimensions(const Shape& shape);
|
||||
|
||||
// Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
|
||||
static Shape MakeShapeWithDescendingLayout(
|
||||
PrimitiveType element_type, absl::Span<const int64> dimensions);
|
||||
|
@ -732,6 +732,15 @@ TEST(ShapeUtilTest, PermuteDimensionsLayout) {
|
||||
} while (std::next_permutation(layout.begin(), layout.end()));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, UpdateDynamicDimensions) {
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
|
||||
|
||||
Shape tuple_shape = ShapeUtil::MakeTupleShape({shape});
|
||||
|
||||
ShapeUtil::UpdateDynamicDimension(&tuple_shape, {0}, 1, true);
|
||||
EXPECT_TRUE(ShapeUtil::GetSubshape(tuple_shape, {0}).is_dynamic_dimension(1));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, PermuteDynamicDimensions) {
|
||||
Shape shape =
|
||||
ShapeUtil::MakeShape(F32, {10, 100, 1000},
|
||||
|
Loading…
Reference in New Issue
Block a user