diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h
index fbf5647f2d8..e8178de3a00 100644
--- a/tensorflow/compiler/xla/shape.h
+++ b/tensorflow/compiler/xla/shape.h
@@ -77,6 +77,10 @@ class Shape {
     return dynamic_dimensions_;
   }
 
+  absl::Span<bool> mutable_dynamic_dimensions() {
+    return absl::MakeSpan(dynamic_dimensions_);
+  }
+
   // Add dimension_upper_bound().
 
   // Removes the given dimension from the shape. Layout, if it exists, is
@@ -127,6 +131,19 @@ class Shape {
   Layout* mutable_layout() { return &layout_; }
   void clear_layout() { layout_.Clear(); }
 
+  // Recursively clears all dynamic dimensions of a shape.
+  void clear_dynamic_dimensions() {
+    if (!IsTuple()) {
+      for (int64 i = 0; i < dynamic_dimensions_.size(); ++i) {
+        dynamic_dimensions_[i] = false;
+      }
+      return;
+    }
+    for (auto& subshape : tuple_shapes_) {
+      subshape.clear_dynamic_dimensions();
+    }
+  }
+
   void Swap(Shape* other) {
     using std::swap;
     swap(*this, *other);
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 9c47303ad44..484673b8b6b 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -176,6 +176,10 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
   return MakeValidatedShape(element_type, dimensions).ValueOrDie();
 }
 
+/* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) {
+  return MakeShape(element_type, {});
+}
+
 /* static */ Shape ShapeUtil::MakeShape(
     PrimitiveType element_type, absl::Span<const int64> dimensions,
     const std::vector<bool>& dynamic_dimensions) {
@@ -183,6 +187,13 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
       .ValueOrDie();
 }
 
+/* static */ Shape ShapeUtil::MakeShapeWithStaticDimensions(
+    const Shape& shape) {
+  Shape output = shape;
+  output.clear_dynamic_dimensions();
+  return output;
+}
+
 /* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
     PrimitiveType element_type, absl::Span<const int64> dimensions) {
   CHECK(IsArrayPrimitiveType(element_type)) << element_type;
@@ -242,6 +253,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
       shape.layout().tiles().begin(), shape.layout().tiles().end());
   new_shape.mutable_layout()->set_element_size_in_bits(
       shape.layout().element_size_in_bits());
+  for (int i = 0; i < shape.dimensions_size(); ++i) {
+    new_shape.set_dynamic_dimension(i, shape.is_dynamic_dimension(i));
+  }
   return new_shape;
 }
 
@@ -294,6 +308,20 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
   *tuple_shape->mutable_tuple_shapes(index) = shape;
 }
 
+/* static */ void ShapeUtil::UpdateDynamicDimension(Shape* shape,
+                                                    ShapeIndexView index,
+                                                    int64 dim,
+                                                    bool is_dynamic) {
+  if (index.empty()) {
+    CHECK(!shape->IsTuple());
+    shape->set_dynamic_dimension(dim, is_dynamic);
+    return;
+  }
+
+  UpdateDynamicDimension(shape->mutable_tuple_shapes(index.front()),
+                         index.ConsumeFront(), dim, is_dynamic);
+}
+
 /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
   CHECK(LayoutUtil::IsDenseArray(*shape));
   shape->mutable_layout()->add_minor_to_major(shape->rank());
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index a4bddc864c8..668274ae714 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -368,6 +368,10 @@ class ShapeUtil {
   static void UpdateTupleShape(const Shape& shape, int64 index,
                                Shape* tuple_shape);
 
+  // Updates the dynamic dimension of a shape. The shape can be a nested tuple.
+  static void UpdateDynamicDimension(Shape* shape, ShapeIndexView index,
+                                     int64 dim, bool is_dynamic);
+
   // Appends a major dimension to the shape with the given bound.
   static void AppendMajorDimension(int bound, Shape* shape);
 
@@ -394,6 +398,9 @@ class ShapeUtil {
                          absl::Span<const int64> dimensions,
                          const std::vector<bool>& dynamic_dimensions);
 
+  // Makes a scalar shape with the given primitive type.
+  static Shape MakeScalarShape(PrimitiveType element_type);
+
   // Constructs a new shape with the given element type and sequence of
   // dimensions. Method checks if the element type is valid and the shape's
   // size fits in std::numeric_limits<int64>::max().
@@ -424,6 +431,9 @@ class ShapeUtil {
                                          absl::Span<const int64> dimensions,
                                          int64 max_sparse_elements);
 
+  // Returns the same shape except with all dimensions set to be static.
+  static Shape MakeShapeWithStaticDimensions(const Shape& shape);
+
   // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
   static Shape MakeShapeWithDescendingLayout(
       PrimitiveType element_type, absl::Span<const int64> dimensions);
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index 4a59fe794c7..4e2030667ee 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -732,6 +732,15 @@ TEST(ShapeUtilTest, PermuteDimensionsLayout) {
   } while (std::next_permutation(layout.begin(), layout.end()));
 }
 
+TEST(ShapeUtilTest, UpdateDynamicDimensions) {
+  Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
+
+  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape});
+
+  ShapeUtil::UpdateDynamicDimension(&tuple_shape, {0}, 1, true);
+  EXPECT_TRUE(ShapeUtil::GetSubshape(tuple_shape, {0}).is_dynamic_dimension(1));
+}
+
 TEST(ShapeUtilTest, PermuteDynamicDimensions) {
   Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000},
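For reference, the sketch below (not part of the patch) shows how the new helpers compose. It is a minimal, illustrative snippet: the function name DynamicDimensionExample is hypothetical, and it assumes the xla namespace and the headers touched above; the Shape/ShapeUtil calls themselves are either introduced by this change or already present in the files it edits.

// Usage sketch (illustrative only, not part of this change). Assumes the
// xla namespace; DynamicDimensionExample is a hypothetical function name.
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

void DynamicDimensionExample() {
  // New convenience helper; equivalent to MakeShape(F32, {}).
  Shape scalar = ShapeUtil::MakeScalarShape(F32);

  // An array shape whose dimension 1 is dynamic with upper bound 100.
  Shape array = ShapeUtil::MakeShape(F32, {10, 100}, {false, true});

  // The new span accessor permits in-place edits of the dynamic bits.
  array.mutable_dynamic_dimensions()[0] = true;

  // Returns a copy with every dimension forced back to static; the
  // dimension bounds themselves are left untouched.
  Shape static_array = ShapeUtil::MakeShapeWithStaticDimensions(array);

  // The tuple-aware helper updates one dimension of a nested subshape,
  // here dimension 1 of tuple element 0.
  Shape tuple = ShapeUtil::MakeTupleShape({array});
  ShapeUtil::UpdateDynamicDimension(&tuple, /*index=*/{0}, /*dim=*/1,
                                    /*is_dynamic=*/false);
}

}  // namespace xla

Note that MakeShapeWithStaticDimensions only clears the dynamic-dimension flags; the dimension sizes (the bounds) and the layout are preserved.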