[XLA] Add a few helper functions around dynamic dimensions.

PiperOrigin-RevId: 282408969
Change-Id: I9058ee17ae239d24ed741246af4288905be10212
This commit is contained in:
Yunxing Dai 2019-11-25 12:20:07 -08:00 committed by TensorFlower Gardener
parent b1e67d9456
commit 827b2f4723
4 changed files with 64 additions and 0 deletions

View File

@ -77,6 +77,10 @@ class Shape {
return dynamic_dimensions_;
}
absl::Span<bool> mutable_dynamic_dimensions() {
return absl::MakeSpan(dynamic_dimensions_);
}
// Add dimension_upper_bound().
// Removes the given dimension form the shape. Layout, if it exists, is
@ -127,6 +131,19 @@ class Shape {
Layout* mutable_layout() { return &layout_; }
void clear_layout() { layout_.Clear(); }
// Recursively clear dynamic dimension of a shape.
void clear_dynamic_dimensions() {
if (!IsTuple()) {
for (int64 i = 0; i < dynamic_dimensions_.size(); ++i) {
dynamic_dimensions_[i] = false;
}
return;
}
for (auto& subshape : tuple_shapes_) {
subshape.clear_dynamic_dimensions();
}
}
void Swap(Shape* other) {
using std::swap;
swap(*this, *other);

View File

@ -176,6 +176,10 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
return MakeValidatedShape(element_type, dimensions).ValueOrDie();
}
/* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) {
return MakeShape(element_type, {});
}
/* static */ Shape ShapeUtil::MakeShape(
PrimitiveType element_type, absl::Span<const int64> dimensions,
const std::vector<bool>& dynamic_dimensions) {
@ -183,6 +187,13 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
.ValueOrDie();
}
/* static */ Shape ShapeUtil::MakeShapeWithStaticDimensions(
const Shape& shape) {
Shape output = shape;
output.clear_dynamic_dimensions();
return output;
}
/* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
PrimitiveType element_type, absl::Span<const int64> dimensions) {
CHECK(IsArrayPrimitiveType(element_type)) << element_type;
@ -242,6 +253,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
shape.layout().tiles().begin(), shape.layout().tiles().end());
new_shape.mutable_layout()->set_element_size_in_bits(
shape.layout().element_size_in_bits());
for (int i = 0; i < shape.dimensions_size(); ++i) {
new_shape.set_dynamic_dimension(i, shape.is_dynamic_dimension(i));
}
return new_shape;
}
@ -294,6 +308,20 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
*tuple_shape->mutable_tuple_shapes(index) = shape;
}
/* static */ void ShapeUtil::UpdateDynamicDimension(Shape* shape,
ShapeIndexView index,
int64 dim,
bool is_dynamic) {
if (index.empty()) {
CHECK(!shape->IsTuple());
shape->set_dynamic_dimension(dim, is_dynamic);
return;
}
UpdateDynamicDimension(shape->mutable_tuple_shapes(index.front()),
index.ConsumeFront(), dim, is_dynamic);
}
/* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
CHECK(LayoutUtil::IsDenseArray(*shape));
shape->mutable_layout()->add_minor_to_major(shape->rank());

View File

@ -368,6 +368,10 @@ class ShapeUtil {
static void UpdateTupleShape(const Shape& shape, int64 index,
Shape* tuple_shape);
// Update the dynamic dimension for a shape. This shape can be a nested tuple.
static void UpdateDynamicDimension(Shape* shape, ShapeIndexView index,
int64 dim, bool is_dynamic);
// Appends a major dimension to the shape with the given bound.
static void AppendMajorDimension(int bound, Shape* shape);
@ -394,6 +398,9 @@ class ShapeUtil {
absl::Span<const int64> dimensions,
const std::vector<bool>& dynamic_dimensions);
// Make a scalar shape with given primitive type.
static Shape MakeScalarShape(PrimitiveType element_type);
// Constructs a new shape with the given element type and sequence of
// dimensions. Method checks if the element type is valid and the shape's
// size fits in std::numeric_limits<int64>::max().
@ -424,6 +431,9 @@ class ShapeUtil {
absl::Span<const int64> dimensions,
int64 max_sparse_elements);
// Returns the same shape except with all dimensions set to be static.
static Shape MakeShapeWithStaticDimensions(const Shape& shape);
// Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
static Shape MakeShapeWithDescendingLayout(
PrimitiveType element_type, absl::Span<const int64> dimensions);

View File

@ -732,6 +732,15 @@ TEST(ShapeUtilTest, PermuteDimensionsLayout) {
} while (std::next_permutation(layout.begin(), layout.end()));
}
TEST(ShapeUtilTest, UpdateDynamicDimensions) {
Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
Shape tuple_shape = ShapeUtil::MakeTupleShape({shape});
ShapeUtil::UpdateDynamicDimension(&tuple_shape, {0}, 1, true);
EXPECT_TRUE(ShapeUtil::GetSubshape(tuple_shape, {0}).is_dynamic_dimension(1));
}
TEST(ShapeUtilTest, PermuteDynamicDimensions) {
Shape shape =
ShapeUtil::MakeShape(F32, {10, 100, 1000},