/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/shape_util.h"

#include <algorithm>
#include <functional>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/ascii.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"

namespace xla {

using absl::StrAppend;
using absl::StrCat;

namespace {
// An array that is indexed by PrimitiveType, and returns the size of each
// element of that primitive type, or 0 if the PrimitiveType does not have a
// fixed element size (TUPLE, OPAQUE_TYPE, TOKEN, or an invalid type).
constexpr uint8 primitive_byte_size[PrimitiveType_ARRAYSIZE] = {
    0,                  // PRIMITIVE_TYPE_INVALID = 0,
    sizeof(int8),       // PRED = 1
    sizeof(int8),       // S8 = 2
    sizeof(int16),      // S16 = 3
    sizeof(int32),      // S32 = 4
    sizeof(int64),      // S64 = 5
    sizeof(uint8),      // U8 = 6
    sizeof(uint16),     // U16 = 7
    sizeof(uint32),     // U32 = 8
    sizeof(uint64),     // U64 = 9
    sizeof(float) / 2,  // F16 = 10
    sizeof(float),      // F32 = 11
    sizeof(double),     // F64 = 12
    0,                  // TUPLE = 13
    0,                  // OPAQUE_TYPE = 14
    sizeof(complex64),  // C64 = 15
    sizeof(float) / 2,  // BF16 = 16
    0,                  // TOKEN = 17
    sizeof(complex128)  // C128 = 18
};
} // namespace
|
|
|
|
string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); }
|
|
|
|
string ShapeIndexView::ToString() const {
|
|
return StrCat("{", absl::StrJoin(indices_, ","), "}");
|
|
}
|
|
|
|
bool ShapeIndexView::operator==(const ShapeIndexView& other) const {
|
|
return indices_ == other.indices_;
|
|
}
|
|
|
|
bool ShapeIndexView::operator!=(const ShapeIndexView& other) const {
|
|
return !(*this == other);
|
|
}
|
|
|
|
std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) {
|
|
out << shape_index.ToString();
|
|
return out;
|
|
}
|
|
|
|
std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) {
|
|
out << shape_index.ToString();
|
|
return out;
|
|
}
|
|
|
|
bool ShapeIndexView::StartsWith(ShapeIndexView prefix) const {
|
|
return size() >= prefix.size() &&
|
|
indices_.subspan(0, prefix.size()) == prefix.indices_;
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::IsArrayPrimitiveType(
|
|
PrimitiveType primitive_type) {
|
|
return primitive_util::IsArrayType(primitive_type);
|
|
}
|
|
|
|
namespace {
|
|
// Constructs and returns the new shape with the given minor_to_major order in
|
|
// its Layout.
|
|
StatusOr<Shape> MakeShapeWithLayoutInternal(
|
|
PrimitiveType element_type, absl::Span<const int64> dimensions,
|
|
absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
|
|
int64 element_size_in_bits, int64 memory_space) {
|
|
if (dimensions.size() != minor_to_major.size()) {
|
|
return InvalidArgument("Dimensions size is %ld, but layout size is %ld.",
|
|
dimensions.size(), minor_to_major.size());
|
|
}
|
|
if (element_type == OPAQUE_TYPE || element_type == TUPLE) {
|
|
return InvalidArgument("Unsupported element type: %s",
|
|
PrimitiveType_Name(element_type));
|
|
}
|
|
TF_ASSIGN_OR_RETURN(Shape shape,
|
|
ShapeUtil::MakeValidatedShape(element_type, dimensions));
|
|
if (element_size_in_bits ==
|
|
ShapeUtil::ByteSizeOfPrimitiveType(element_type) * 8) {
|
|
// Only set element_size_in_bits if it's different from the default value.
|
|
element_size_in_bits = 0;
|
|
}
|
|
*shape.mutable_layout() = LayoutUtil::MakeLayout(
|
|
minor_to_major, tiles, element_size_in_bits, memory_space);
|
|
if (!shape.has_layout()) {
|
|
return InvalidArgument("Shape has no layout.");
|
|
}
|
|
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
|
|
return shape;
|
|
}
|
|
} // namespace
|
|
|
|
/* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
|
|
bool equal = Shape::Equal()(lhs, rhs);
|
|
|
|
if (!equal && VLOG_IS_ON(3)) {
|
|
VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString()
|
|
<< ", rhs = " << rhs.ShortDebugString();
|
|
}
|
|
|
|
return equal;
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::EqualIgnoringElementType(const Shape& lhs,
|
|
const Shape& rhs) {
|
|
bool equal = Shape::Equal().IgnoreElementType()(lhs, rhs);
|
|
if (!equal && VLOG_IS_ON(3)) {
|
|
VLOG(3) << "ShapeUtil::EqualIgnoringElementType differ: lhs = "
|
|
<< lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
|
|
}
|
|
|
|
return equal;
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs,
|
|
const Shape& rhs) {
|
|
bool equal = Shape::Equal().IgnoreFpPrecision()(lhs, rhs);
|
|
if (!equal && VLOG_IS_ON(3)) {
|
|
VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = "
|
|
<< lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
|
|
}
|
|
|
|
return equal;
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::EqualStructure(const Shape& lhs,
|
|
const Shape& rhs) {
|
|
bool equal = true;
|
|
ForEachSubshape(lhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) {
|
|
equal &= IndexIsValid(rhs, index);
|
|
});
|
|
ForEachSubshape(rhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) {
|
|
equal &= IndexIsValid(lhs, index);
|
|
});
|
|
|
|
return equal;
|
|
}
|
|
|
|
/* static */ int64 ShapeUtil::TrueRank(const Shape& shape) {
  int64 accum = 0;
  for (int64 dimension : shape.dimensions()) {
    // We do not count degenerate (size 1) dimensions.
    if (dimension != 1) {
      accum += 1;
    }
  }
  return accum;
}
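
// Illustrative example (assumed, not part of the original file): a shape with
// dimension bounds {1, 2, 1, 3} has rank 4 but TrueRank 2, because the two
// degenerate dimensions of size 1 are not counted.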
|
|
|
|
/* static */ bool ShapeUtil::FillNewShape(PrimitiveType element_type,
|
|
absl::Span<const int64> dimensions,
|
|
Shape* shape) {
|
|
const int eint = static_cast<int>(element_type);
|
|
int64 dense_shape_size = ((eint >= 0 && eint < PrimitiveType_ARRAYSIZE)
|
|
? primitive_byte_size[eint]
|
|
: 0); // Out of range: force a failure
|
|
if (dense_shape_size <= 0) {
|
|
return false;
|
|
}
|
|
|
|
// Verify that array-based lookup is consistent with public API.
|
|
DCHECK_EQ(dense_shape_size, ByteSizeOfPrimitiveType(element_type))
|
|
<< element_type;
|
|
|
|
shape->set_element_type(element_type);
|
|
const int ndims = dimensions.size();
|
|
auto layout = shape->mutable_layout();
|
|
layout->set_format(DENSE);
|
|
auto* minor_to_major = layout->mutable_minor_to_major();
|
|
for (int i = 0; i < ndims; i++) {
|
|
const int64 d = dimensions[i];
|
|
if (d < 0) {
|
|
return false;
|
|
}
|
|
dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, d);
|
|
if (dense_shape_size < 0) {
|
|
return false;
|
|
}
|
|
|
|
shape->add_dimensions(d);
|
|
minor_to_major->push_back(ndims - 1 - i);
|
|
}
|
|
return true;
|
|
}
|
|
|
|
/* static */ ProgramShape ShapeUtil::MakeProgramShape(
|
|
std::initializer_list<Shape> parameters, Shape result) {
|
|
ProgramShape program_shape;
|
|
for (const Shape& shape : parameters) {
|
|
*program_shape.add_parameters() = shape;
|
|
}
|
|
*program_shape.mutable_result() = std::move(result);
|
|
return program_shape;
|
|
}
|
|
|
|
/* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type,
|
|
absl::Span<const int64> dimensions) {
|
|
Shape shape;
|
|
CHECK(FillNewShape(element_type, dimensions, &shape));
|
|
return shape;
|
|
}
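
// Example usage (illustrative sketch, not part of the original file):
//   Shape s = ShapeUtil::MakeShape(F32, {2, 3});
// produces an f32 array shape with dimension bounds {2, 3} and the default
// descending (major-to-minor) layout {1, 0} filled in by FillNewShape.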
|
|
|
|
/* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) {
|
|
return MakeShape(element_type, {});
|
|
}
|
|
|
|
/* static */ Shape ShapeUtil::MakeShape(
|
|
PrimitiveType element_type, absl::Span<const int64> dimensions,
|
|
const std::vector<bool>& dynamic_dimensions) {
|
|
return MakeValidatedShape(element_type, dimensions, dynamic_dimensions)
|
|
.ValueOrDie();
|
|
}
|
|
|
|
/* static */ Shape ShapeUtil::MakeShapeWithStaticDimensions(
|
|
const Shape& shape) {
|
|
Shape output = shape;
|
|
output.clear_dynamic_dimensions();
|
|
return output;
|
|
}
|
|
|
|
/* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
|
|
PrimitiveType element_type, absl::Span<const int64> dimensions) {
|
|
Shape shape;
|
|
if (!FillNewShape(element_type, dimensions, &shape)) {
|
|
return InvalidArgument("invalid shape type=%d, dims=[%s]",
|
|
static_cast<int>(element_type),
|
|
absl::StrJoin(dimensions, ","));
|
|
}
|
|
return shape;
|
|
}
|
|
|
|
/* static */ StatusOr<Shape> ShapeUtil::MakeValidatedShape(
|
|
PrimitiveType element_type, absl::Span<const int64> dimensions,
|
|
const std::vector<bool>& dynamic_dimensions) {
|
|
if (dynamic_dimensions.size() != dimensions.size()) {
|
|
return InvalidArgument(
|
|
"dynamic dimensions size %d did not match number of dimensions %d",
|
|
dynamic_dimensions.size(), dimensions.size());
|
|
}
|
|
|
|
Shape shape;
|
|
if (!FillNewShape(element_type, dimensions, &shape)) {
|
|
return InvalidArgument("invalid shape type=%d, dims=[%s]",
|
|
static_cast<int>(element_type),
|
|
absl::StrJoin(dimensions, ","));
|
|
}
|
|
for (int i = 0, n = dimensions.size(); i < n; i++) {
|
|
shape.set_dynamic_dimension(i, dynamic_dimensions[i]);
|
|
}
|
|
return shape;
|
|
}
|
|
|
|
/* static */ Shape ShapeUtil::MakeShapeWithLayout(
|
|
PrimitiveType element_type, absl::Span<const int64> dimensions,
|
|
absl::Span<const int64> minor_to_major, absl::Span<const Tile> tiles,
|
|
int64 element_size_in_bits, int64 memory_space) {
|
|
auto ret =
|
|
MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major,
|
|
tiles, element_size_in_bits, memory_space);
|
|
if (!ret.ok()) LOG(ERROR) << ret.status();
|
|
return ret.ValueOrDie();
|
|
}
|
|
|
|
/* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout(
|
|
PrimitiveType element_type, absl::Span<const int64> dimensions) {
|
|
std::vector<int64> layout(dimensions.size());
|
|
std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
|
|
return MakeShapeWithLayout(element_type, dimensions, layout);
|
|
}
|
|
|
|
/* static */ Shape
|
|
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
|
const Shape& shape) {
|
|
std::vector<int64> dims(shape.dimensions_size());
|
|
for (int i = 0; i < shape.dimensions_size(); ++i) {
|
|
dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i));
|
|
}
|
|
Shape new_shape = MakeShapeWithDescendingLayout(shape.element_type(), dims);
|
|
// Since the physical layout is kept the same, the tiles and element size are
|
|
// the same also.
|
|
new_shape.mutable_layout()->mutable_tiles()->assign(
|
|
shape.layout().tiles().begin(), shape.layout().tiles().end());
|
|
new_shape.mutable_layout()->set_element_size_in_bits(
|
|
shape.layout().element_size_in_bits());
|
|
for (int i = 0; i < shape.dimensions_size(); ++i) {
|
|
new_shape.set_dynamic_dimension(i, shape.is_dynamic_dimension(i));
|
|
}
|
|
return new_shape;
|
|
}
|
|
|
|
/* static */ Status ShapeUtil::PopulateShape(PrimitiveType element_type,
|
|
absl::Span<const int64> dimensions,
|
|
Shape* shape) {
|
|
shape->Clear();
|
|
shape->set_element_type(element_type);
|
|
for (int64 dimension : dimensions) {
|
|
shape->add_dimensions(dimension);
|
|
}
|
|
LayoutUtil::SetToDefaultLayout(shape);
|
|
return ValidateShape(*shape);
|
|
}
|
|
|
|
/* static */ Shape ShapeUtil::MakeStaticShape(const Shape& original) {
|
|
Shape result = original;
|
|
result.clear_dynamic_dimensions();
|
|
return result;
|
|
}
|
|
|
|
/* static */ Shape ShapeUtil::MakeTupleShape(absl::Span<const Shape> shapes) {
|
|
Shape result;
|
|
result.set_element_type(TUPLE);
|
|
result.mutable_tuple_shapes()->reserve(shapes.size());
|
|
for (const auto& shape : shapes) {
|
|
AppendShapeToTuple(shape, &result);
|
|
}
|
|
TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
|
|
return result;
|
|
}
|
|
|
|
/* static */ Shape ShapeUtil::MakeOpaqueShape() {
|
|
Shape result;
|
|
result.set_element_type(OPAQUE_TYPE);
|
|
TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
|
|
return result;
|
|
}
|
|
|
|
/* static */ Shape ShapeUtil::MakeTokenShape() {
|
|
Shape result;
|
|
result.set_element_type(TOKEN);
|
|
TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
|
|
return result;
|
|
}
|
|
|
|
/* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape,
|
|
Shape* tuple_shape) {
|
|
TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
|
|
*tuple_shape->add_tuple_shapes() = shape;
|
|
}
|
|
|
|
/* static */ void ShapeUtil::UpdateTupleShape(const Shape& shape, int64 index,
|
|
Shape* tuple_shape) {
|
|
CHECK(index < tuple_shape->tuple_shapes_size());
|
|
*tuple_shape->mutable_tuple_shapes(index) = shape;
|
|
}
|
|
|
|
/* static */ void ShapeUtil::UpdateDynamicDimension(Shape* shape,
|
|
ShapeIndexView index,
|
|
int64 dim,
|
|
bool is_dynamic) {
|
|
if (index.empty()) {
|
|
CHECK(!shape->IsTuple());
|
|
shape->set_dynamic_dimension(dim, is_dynamic);
|
|
return;
|
|
}
|
|
|
|
UpdateDynamicDimension(shape->mutable_tuple_shapes(index.front()),
|
|
index.ConsumeFront(), dim, is_dynamic);
|
|
}
|
|
|
|
/* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
|
|
CHECK(LayoutUtil::IsDenseArray(*shape));
|
|
shape->mutable_layout()->add_minor_to_major(shape->rank());
|
|
shape->add_dimensions(bound);
|
|
TF_DCHECK_OK(ValidateShape(*shape));
|
|
}
|
|
|
|
/* static */ void ShapeUtil::CopyDynamicDimensions(Shape* to,
|
|
const Shape& from) {
|
|
CHECK_EQ(to->rank(), from.rank());
|
|
for (int64 i = 0; i < from.rank(); ++i) {
|
|
to->set_dynamic_dimension(i, from.is_dynamic_dimension(i));
|
|
}
|
|
TF_DCHECK_OK(ValidateShape(*to));
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::ElementIsIntegral(const Shape& shape) {
|
|
return primitive_util::IsIntegralType(shape.element_type());
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::ElementIsIntegralWithBits(const Shape& shape,
|
|
int32 bits) {
|
|
return ElementIsIntegral(shape) && ElementHasBitWidth(shape, bits);
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) {
|
|
if (!shape.IsArray()) {
|
|
return false;
|
|
}
|
|
return primitive_util::BitWidth(shape.element_type()) == bits;
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::ElementIsSigned(const Shape& shape) {
|
|
switch (shape.element_type()) {
|
|
case S8:
|
|
case S16:
|
|
case S32:
|
|
case S64:
|
|
case F16:
|
|
case BF16:
|
|
case F32:
|
|
case F64:
|
|
return true;
|
|
|
|
case PRED:
|
|
case U8:
|
|
case U16:
|
|
case U32:
|
|
case U64:
|
|
case C64:
|
|
case C128:
|
|
case TUPLE:
|
|
case OPAQUE_TYPE:
|
|
case TOKEN:
|
|
return false;
|
|
|
|
default:
|
|
LOG(FATAL) << "Unhandled element type " << shape.element_type();
|
|
}
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::ElementIsComplex(const Shape& shape) {
|
|
return primitive_util::IsComplexType(shape.element_type());
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::ElementIsFloating(const Shape& shape) {
|
|
return primitive_util::IsFloatingPointType(shape.element_type());
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
|
|
return shape.IsTuple() &&
|
|
absl::c_any_of(shape.tuple_shapes(),
|
|
[](const Shape& s) { return s.IsTuple(); });
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) {
|
|
return shape.IsTuple() && TupleElementCount(shape) == 0;
|
|
}
|
|
|
|
/* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) {
|
|
CHECK(shape.IsTuple()) << HumanString(shape);
|
|
return shape.tuple_shapes_size();
|
|
}
|
|
|
|
/* static */ const Shape& ShapeUtil::GetTupleElementShape(const Shape& shape,
|
|
int64 index) {
|
|
CHECK(shape.IsTuple());
|
|
CHECK_GT(TupleElementCount(shape), index);
|
|
TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape.tuple_shapes(index)));
|
|
return shape.tuple_shapes(index);
|
|
}
|
|
|
|
/* static */ int64 ShapeUtil::SubshapeCount(const Shape& shape) {
|
|
int64 n = 0;
|
|
ForEachSubshape(shape, [&](const Shape& literal_subshape,
|
|
const ShapeIndex& index) { ++n; });
|
|
return n;
|
|
}
|
|
|
|
/* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start,
|
|
int64 limit) {
|
|
TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple));
|
|
CHECK(tuple.IsTuple());
|
|
CHECK_LE(start, TupleElementCount(tuple));
|
|
CHECK_LE(limit, TupleElementCount(tuple));
|
|
|
|
std::vector<Shape> new_elements(tuple.tuple_shapes().begin() + start,
|
|
tuple.tuple_shapes().begin() + limit);
|
|
return MakeTupleShape(new_elements);
|
|
}
|
|
|
|
// Returns the shape of a real or imaginary component.
|
|
/* static */ Shape ShapeUtil::ComplexComponentShape(
|
|
const Shape& complex_shape) {
|
|
CHECK(ElementIsComplex(complex_shape)) << HumanString(complex_shape);
|
|
return ChangeElementType(complex_shape, primitive_util::ComplexComponentType(
|
|
complex_shape.element_type()));
|
|
}
|
|
|
|
/* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
|
|
DCHECK(shape.IsArray()) << ShapeUtil::HumanString(shape);
|
|
DCHECK_EQ(shape.dimensions_size(), shape.rank());
|
|
if (shape.dimensions().size() == 1) {
|
|
return shape.dimensions()[0];
|
|
}
|
|
return std::accumulate<decltype(shape.dimensions().begin()), int64>(
|
|
shape.dimensions().begin(), shape.dimensions().end(), 1LL,
|
|
std::multiplies<int64>());
|
|
}
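
// Illustrative example (assumed): ElementsIn returns the product of the
// dimension bounds, e.g. 2 * 3 * 4 = 24 for an f32[2,3,4] shape, and 1 for a
// scalar (empty dimension list).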
|
|
|
|
/* static */ int64 ShapeUtil::ElementsInRecursive(const Shape& shape) {
|
|
CHECK(shape.IsArray() || shape.IsTuple());
|
|
if (shape.IsArray()) {
|
|
return ElementsIn(shape);
|
|
}
|
|
int64 count = 0;
|
|
for (const Shape& element_shape : shape.tuple_shapes()) {
|
|
count += ElementsInRecursive(element_shape);
|
|
}
|
|
return count;
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape,
|
|
PrimitiveType primitive_type) {
|
|
if (shape.element_type() == primitive_type) {
|
|
return true;
|
|
}
|
|
for (const Shape& element_shape : shape.tuple_shapes()) {
|
|
if (HasPrimitiveType(element_shape, primitive_type)) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
|
|
return shape.IsArray() && ElementsIn(shape) == 0;
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::IsScalarWithElementType(
|
|
const Shape& shape, PrimitiveType element_type) {
|
|
return IsScalar(shape) && shape.element_type() == element_type;
|
|
}
|
|
|
|
/* static */ string ShapeUtil::HumanString(const Shape& shape) {
|
|
if (shape.IsTuple()) {
|
|
string text = "(";
|
|
const char* prefix = "";
|
|
for (const Shape& elem_shape : shape.tuple_shapes()) {
|
|
StrAppend(&text, prefix, HumanString(elem_shape));
|
|
prefix = ", ";
|
|
}
|
|
text += ")";
|
|
return text;
|
|
}
|
|
std::vector<string> dim_elements;
|
|
for (int i = 0; i < shape.dimensions_size(); ++i) {
|
|
if (shape.is_dynamic_dimension(i)) {
|
|
dim_elements.push_back(StrCat("<=", shape.dimensions(i)));
|
|
} else {
|
|
dim_elements.push_back(StrCat(shape.dimensions(i)));
|
|
}
|
|
}
|
|
return StrCat(
|
|
primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[",
|
|
absl::StrJoin(dim_elements, ","), "]");
|
|
}
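
// Illustrative examples of the format produced above (assumed, not asserted
// here): an f32 array with bounds {2, 3} prints as "f32[2,3]", a dynamic
// second dimension prints as "f32[2,<=3]", and a tuple prints as
// "(f32[2,3], s32[])".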
|
|
|
|
/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
|
|
if (shape.IsTuple()) {
|
|
string text = "(";
|
|
const char* prefix = "";
|
|
for (const Shape& elem_shape : shape.tuple_shapes()) {
|
|
StrAppend(&text, prefix, HumanStringWithLayout(elem_shape));
|
|
prefix = ", ";
|
|
}
|
|
text += ")";
|
|
return text;
|
|
}
|
|
string result = HumanString(shape);
|
|
if (IsScalar(shape)) {
|
|
string layout_str = LayoutUtil::HumanString(shape.layout());
|
|
// Don't print "{}" as layout for scalars.
|
|
if (layout_str != "{}") {
|
|
StrAppend(&result, layout_str);
|
|
}
|
|
} else if (shape.IsArray() && LayoutUtil::HasLayout(shape)) {
|
|
StrAppend(&result, LayoutUtil::HumanString(shape.layout()));
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) {
|
|
std::vector<string> parameters;
|
|
for (auto& shape : program_shape.parameters()) {
|
|
const int i = parameters.size();
|
|
parameters.push_back(StrCat(i < program_shape.parameter_names_size()
|
|
? program_shape.parameter_names(i)
|
|
: "(unknown)",
|
|
": ", HumanString(shape)));
|
|
}
|
|
return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ",
|
|
HumanString(program_shape.result()));
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::SameDimensions(const Shape& lhs,
|
|
const Shape& rhs) {
|
|
CHECK(lhs.IsArray());
|
|
CHECK(rhs.IsArray());
|
|
return absl::c_equal(lhs.dimensions(), rhs.dimensions());
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::SameRank(const Shape& lhs, const Shape& rhs) {
|
|
CHECK(lhs.IsArray());
|
|
CHECK(rhs.IsArray());
|
|
return lhs.rank() == rhs.rank();
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
|
|
return Shape::Equal().IgnoreDynamicDimension().IgnoreLayout()(lhs, rhs);
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
|
|
const Shape& rhs) {
|
|
return Shape::Equal()
|
|
.IgnoreDynamicDimension()
|
|
.IgnoreElementType()
|
|
.IgnoreLayout()(lhs, rhs);
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::CompatibleKind(const Shape& lhs,
|
|
const Shape& rhs) {
|
|
return Shape::Equal()
|
|
.IgnoreElementType()
|
|
.IgnoreLayout()
|
|
.IgnoreDimensions()
|
|
.IgnoreDynamicDimension()(lhs, rhs);
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
|
|
const Shape& rhs) {
|
|
return Shape::Equal()
|
|
.IgnoreDynamicDimension()
|
|
.IgnoreFpPrecision()
|
|
.IgnoreLayout()(lhs, rhs);
|
|
}
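
// Illustrative contrast between the predicates above (assumed): two f32 arrays
// with bounds {2,3} but different layouts are Compatible (layout is ignored)
// yet not Equal; an f32[2,3] and an f16[2,3] are CompatibleIgnoringFpPrecision
// but not Compatible; and an f32[2,3] and an s32[2,3] are only
// CompatibleIgnoringElementType.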
|
|
|
|
/* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
|
|
int64 dimension_number) {
|
|
return shape.dimensions(GetDimensionNumber(shape, dimension_number));
|
|
}
|
|
|
|
/* static */ int64 ShapeUtil::GetDimensionNumber(const Shape& shape,
|
|
int64 dimension_number) {
|
|
if (dimension_number < 0) {
|
|
dimension_number += shape.rank();
|
|
}
|
|
CHECK_GE(dimension_number, 0);
|
|
return dimension_number;
|
|
}
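
// Illustrative example (assumed): negative dimension numbers count from the
// back, so GetDimension(shape, -1) returns shape.dimensions(shape.rank() - 1),
// the bound of the last logical dimension.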
|
|
|
|
/* static */ int64 ShapeUtil::ByteSizeOfPrimitiveType(
|
|
PrimitiveType primitive_type) {
|
|
switch (primitive_type) {
|
|
case PRED:
|
|
return sizeof(int8);
|
|
case S8:
|
|
return sizeof(int8);
|
|
case S16:
|
|
return sizeof(int16);
|
|
case S32:
|
|
return sizeof(int32);
|
|
case S64:
|
|
return sizeof(int64);
|
|
case U8:
|
|
return sizeof(uint8);
|
|
case U16:
|
|
return sizeof(uint16);
|
|
case U32:
|
|
return sizeof(uint32);
|
|
case U64:
|
|
return sizeof(uint64);
|
|
case BF16:
|
|
return sizeof(float) / 2;
|
|
case F16:
|
|
return sizeof(float) / 2;
|
|
case F32:
|
|
return sizeof(float);
|
|
case F64:
|
|
return sizeof(double);
|
|
case C64:
|
|
return sizeof(complex64);
|
|
case C128:
|
|
return sizeof(complex128);
|
|
case TOKEN:
|
|
// Tokens require no space.
|
|
return 0;
|
|
case TUPLE:
|
|
case OPAQUE_TYPE:
|
|
LOG(FATAL) << PrimitiveType_Name(primitive_type)
|
|
<< " primitive type has no definitive size";
|
|
default:
|
|
LOG(FATAL) << "Unhandled primitive type " << primitive_type;
|
|
}
|
|
}
|
|
|
|
/* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape,
|
|
int64 pointer_size) {
|
|
TF_DCHECK_OK(ValidateShape(shape));
|
|
if (shape.element_type() == TUPLE) {
|
|
return ByteSizeOfTupleIndexTable(shape, pointer_size);
|
|
} else if (shape.IsArray()) {
|
|
return ByteSizeOfElements(shape);
|
|
} else if (shape.element_type() == TOKEN) {
|
|
return 0;
|
|
} else if (shape.element_type() == OPAQUE_TYPE) {
|
|
CHECK_GT(pointer_size, 0);
|
|
return pointer_size;
|
|
}
|
|
LOG(FATAL) << PrimitiveType_Name(shape.element_type())
|
|
<< " primitive type has no definitive size";
|
|
}
|
|
|
|
/* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape,
|
|
int64 pointer_size) {
|
|
TF_DCHECK_OK(ValidateShape(shape));
|
|
CHECK_EQ(TUPLE, shape.element_type());
|
|
CHECK_GT(pointer_size, 0);
|
|
return pointer_size * shape.tuple_shapes_size();
|
|
}
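
// Illustrative example (assumed): with 8-byte pointers, the index table of a
// 3-element tuple occupies 3 * 8 = 24 bytes, independent of the element
// shapes themselves.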
|
|
|
|
/* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) {
|
|
TF_DCHECK_OK(ValidateShape(shape));
|
|
CHECK(shape.IsArray());
|
|
int64 allocated_element_count;
|
|
|
|
CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString();
|
|
allocated_element_count = ElementsIn(shape);
|
|
return allocated_element_count *
|
|
ByteSizeOfPrimitiveType(shape.element_type());
|
|
}
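
// Illustrative example (assumed): a dense f32[2,3] array holds 6 elements of
// 4 bytes each, so ByteSizeOfElements returns 24.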
|
|
|
|
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
|
|
const Shape& shape) {
|
|
if (shape.element_type() == PRIMITIVE_TYPE_INVALID ||
|
|
!PrimitiveType_IsValid(shape.element_type())) {
|
|
return InvalidArgument("shape has invalid element type: %s",
|
|
shape.ShortDebugString());
|
|
}
|
|
if (shape.element_type() == TUPLE) {
|
|
if (shape.dimensions_size() != 0) {
|
|
return InvalidArgument("tuples must not have dimensions specified");
|
|
}
|
|
for (auto& element_shape : shape.tuple_shapes()) {
|
|
TF_RETURN_IF_ERROR(
|
|
ValidateShapeWithOptionalLayoutInternal(element_shape));
|
|
}
|
|
return Status::OK();
|
|
}
|
|
|
|
// Non-tuple shape.
|
|
if (shape.tuple_shapes_size() > 0) {
|
|
return InvalidArgument("non-tuple shape has tuple_shapes field");
|
|
}
|
|
|
|
// Tokens and opaques should not have layout or dimensions.
|
|
if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE_TYPE) {
|
|
if (shape.dimensions_size() != 0) {
|
|
return InvalidArgument(
|
|
"shape has %s element type, but has dimensions field: %s",
|
|
primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
|
|
shape.ShortDebugString());
|
|
}
|
|
if (shape.has_layout()) {
|
|
return InvalidArgument(
|
|
"shape has %s element type, but has layout field: %s",
|
|
primitive_util::LowercasePrimitiveTypeName(shape.element_type()),
|
|
shape.ShortDebugString());
|
|
}
|
|
return Status::OK();
|
|
}
|
|
|
|
for (int64 i = 0; i < shape.rank(); ++i) {
|
|
int64 dimension = shape.dimensions(i);
|
|
if (dimension < 0) {
|
|
return InvalidArgument(
|
|
"shape's dimensions must not be < 0; dimension at index %d was %d", i,
|
|
dimension);
|
|
}
|
|
}
|
|
|
|
TF_RETURN_IF_ERROR(ValidateShapeSize(shape));
|
|
return Status::OK();
|
|
}
|
|
|
|
/* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) {
|
|
VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape);
|
|
|
|
if (!shape.IsArray()) {
|
|
return Status::OK();
|
|
}
|
|
|
|
int64 shape_size = [&]() {
|
|
int64 dense_shape_size = 1;
|
|
if (shape.dimensions().empty()) {
|
|
return dense_shape_size;
|
|
}
|
|
|
|
absl::Span<const int64> shape_max_dimensions =
|
|
AsInt64Slice(shape.dimensions());
|
|
for (int64 dim : shape_max_dimensions) {
|
|
dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, dim);
|
|
if (dense_shape_size < 0) {
|
|
return dense_shape_size;
|
|
}
|
|
}
|
|
dense_shape_size = MultiplyWithoutOverflow(
|
|
dense_shape_size, ByteSizeOfPrimitiveType(shape.element_type()));
|
|
return dense_shape_size;
|
|
}();
|
|
|
|
if (shape_size < 0) {
|
|
return InvalidArgument("Shape %s size may overflow int64.",
|
|
ShapeUtil::HumanString(shape));
|
|
}
|
|
|
|
VLOG(3) << "Shape size is valid: " << shape_size;
|
|
return Status::OK();
|
|
}
|
|
|
|
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayout(
|
|
const Shape& shape) {
|
|
TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));
|
|
|
|
return LayoutUtil::ValidateLayoutInShape(shape,
|
|
/*allow_missing_layouts=*/true);
|
|
}
|
|
|
|
/* static */ Status ShapeUtil::ValidateShape(const Shape& shape) {
|
|
TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));
|
|
|
|
return LayoutUtil::ValidateLayoutInShape(shape);
|
|
}
|
|
|
|
/* static */ Shape ShapeUtil::ChangeElementType(const Shape& original,
|
|
PrimitiveType type) {
|
|
if (original.IsTuple()) {
|
|
std::vector<Shape> new_operands;
|
|
new_operands.reserve(original.tuple_shapes_size());
|
|
for (const Shape& operand : original.tuple_shapes()) {
|
|
new_operands.push_back(ChangeElementType(operand, type));
|
|
}
|
|
return MakeTupleShape(new_operands);
|
|
} else {
|
|
Shape new_shape = original;
|
|
new_shape.set_element_type(type);
|
|
return new_shape;
|
|
}
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::IndexIsValid(const Shape& shape,
|
|
ShapeIndexView index) {
|
|
const Shape* subshape = &shape;
|
|
for (auto i : index) {
|
|
if (!subshape->IsTuple() || i >= subshape->tuple_shapes_size() || i < 0) {
|
|
return false;
|
|
}
|
|
subshape = &subshape->tuple_shapes(i);
|
|
}
|
|
return true;
|
|
}
|
|
|
|
/* static */ const Shape& ShapeUtil::GetSubshape(const Shape& shape,
|
|
ShapeIndexView index) {
|
|
const Shape* return_shape = &shape;
|
|
for (auto i : index) {
|
|
CHECK(return_shape->IsTuple())
|
|
<< "Invalid index " << index << " for shape " << shape;
|
|
return_shape = &return_shape->tuple_shapes(i);
|
|
}
|
|
return *return_shape;
|
|
}
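
// Illustrative example (assumed): for the tuple shape (f32[2], (s32[], pred[3])),
// GetSubshape with index {1, 0} walks into the nested tuple and returns the
// s32[] subshape; an empty index returns the shape itself.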
|
|
|
|
/* static */ StatusOr<const Shape*> ShapeUtil::TryGetSubshape(
|
|
const Shape& shape, ShapeIndexView index) {
|
|
const Shape* return_shape = &shape;
|
|
for (auto i : index) {
|
|
if (!return_shape->IsTuple() || i < 0 ||
|
|
i >= return_shape->tuple_shapes_size()) {
|
|
return InvalidArgument(
|
|
"Shape index %s not a valid subshape index for tuple with shape %s",
|
|
index.ToString(), shape.DebugString());
|
|
}
|
|
return_shape = &return_shape->tuple_shapes(i);
|
|
}
|
|
return return_shape;
|
|
}
|
|
|
|
/* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape,
|
|
ShapeIndexView index) {
|
|
Shape* return_shape = shape;
|
|
for (auto i : index) {
|
|
CHECK(return_shape->IsTuple());
|
|
return_shape = return_shape->mutable_tuple_shapes(i);
|
|
}
|
|
return return_shape;
|
|
}
|
|
|
|
/* static */
|
|
bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
|
|
return !GetSubshape(shape, index).IsTuple();
|
|
}
|
|
|
|
/* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) {
|
|
if (!shape.IsTuple()) {
|
|
return 1;
|
|
}
|
|
int64 count = 0;
|
|
for (const Shape& subshape : shape.tuple_shapes()) {
|
|
count += GetLeafCount(subshape);
|
|
}
|
|
return count;
|
|
}
|
|
|
|
/* static */ std::vector<ShapeUtil::IndexedShape> ShapeUtil::GetLeafShapes(
|
|
const Shape& shape) {
|
|
std::vector<IndexedShape> leaves;
|
|
ForEachSubshape(shape, [&](const Shape& sub_shape, const ShapeIndex& index) {
|
|
if (IsLeafIndex(shape, index)) {
|
|
leaves.emplace_back(index, sub_shape);
|
|
}
|
|
});
|
|
return leaves;
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) {
|
|
CHECK(shape.IsArray());
|
|
return absl::c_linear_search(shape.dimensions(), 1);
|
|
}
|
|
|
|
/* static */ Shape ShapeUtil::DropDegenerateDimensions(const Shape& shape) {
|
|
return FilterDimensions(
|
|
[&](int64 dim) -> bool { return shape.dimensions()[dim] != 1; }, shape);
|
|
}
|
|
|
|
namespace {
|
|
|
|
// Helper for ForEachSubshape which visits the subshapes of the given shape in
|
|
// DFS pre-order starting with the index.
|
|
Status ForEachSubshapeHelper(const Shape& shape,
|
|
const ShapeUtil::StatusVisitorFunction& func,
|
|
ShapeIndex* index) {
|
|
TF_RETURN_IF_ERROR(func(shape, *index));
|
|
if (shape.IsTuple()) {
|
|
for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
|
|
index->push_back(i);
|
|
TF_RETURN_IF_ERROR(ForEachSubshapeHelper(
|
|
ShapeUtil::GetTupleElementShape(shape, i), func, index));
|
|
index->pop_back();
|
|
}
|
|
}
|
|
return Status::OK();
|
|
}
|
|
|
|
// Helper for ForEachMutableSubshape which visits the subshapes of the given
|
|
// shape in DFS pre-order starting with the index.
|
|
Status ForEachMutableSubshapeHelper(
|
|
Shape* shape, const ShapeUtil::MutatingStatusVisitorFunction& func,
|
|
ShapeIndex* index) {
|
|
TF_RETURN_IF_ERROR(func(shape, *index));
|
|
if (shape->IsTuple()) {
|
|
for (int64 i = 0; i < ShapeUtil::TupleElementCount(*shape); ++i) {
|
|
index->push_back(i);
|
|
TF_RETURN_IF_ERROR(ForEachMutableSubshapeHelper(
|
|
shape->mutable_tuple_shapes(i), func, index));
|
|
index->pop_back();
|
|
}
|
|
}
|
|
return Status::OK();
|
|
}
|
|
|
|
} // namespace
|
|
|
|
/* static */ void ShapeUtil::ForEachSubshape(const Shape& shape,
|
|
const VisitorFunction& func) {
|
|
ShapeIndex index;
|
|
ForEachSubshapeHelper(
|
|
shape,
|
|
[&func](const Shape& subshape, const ShapeIndex& index) {
|
|
func(subshape, index);
|
|
return Status::OK();
|
|
},
|
|
&index)
|
|
.IgnoreError();
|
|
}
|
|
|
|
/* static */ void ShapeUtil::ForEachMutableSubshape(
|
|
Shape* shape, const MutatingVisitorFunction& func) {
|
|
ShapeIndex index;
|
|
ForEachMutableSubshapeHelper(
|
|
shape,
|
|
[&func](Shape* subshape, const ShapeIndex& index) {
|
|
func(subshape, index);
|
|
return Status::OK();
|
|
},
|
|
&index)
|
|
.IgnoreError();
|
|
}
|
|
|
|
/* static */ Status ShapeUtil::ForEachSubshapeWithStatus(
|
|
const Shape& shape, const StatusVisitorFunction& func) {
|
|
ShapeIndex index;
|
|
return ForEachSubshapeHelper(shape, func, &index);
|
|
}
|
|
|
|
/* static */ Status ShapeUtil::ForEachMutableSubshapeWithStatus(
|
|
Shape* shape, const MutatingStatusVisitorFunction& func) {
|
|
ShapeIndex index;
|
|
return ForEachMutableSubshapeHelper(shape, func, &index);
|
|
}
|
|
|
|
/* static */ Shape ShapeUtil::PermuteDimensions(
|
|
absl::Span<const int64> permutation, const Shape& shape) {
|
|
Shape new_shape = shape;
|
|
new_shape.clear_dimensions();
|
|
for (auto dim : Permute(shape.dimensions(), permutation)) {
|
|
new_shape.add_dimensions(dim);
|
|
}
|
|
auto inv_permutation = InversePermutation(permutation);
|
|
for (int64 i = 0; i < shape.rank(); i++) {
|
|
new_shape.set_dynamic_dimension(inv_permutation[i],
|
|
shape.is_dynamic_dimension(i));
|
|
}
|
|
|
|
// If `shape` has a layout, by contract we choose a new layout such that the
|
|
// transpose defined by this permutation is a bitcast.
|
|
//
|
|
// Some formalism helps to understand the correct way to do this. We're going
|
|
// to do algebra in the group of permutations of the dimensions of `shape`.
|
|
//
|
|
// Since the order of `shape`'s dimensions is not permuted relative to itself,
|
|
// `shape`'s list of dimensions is isomorphic to the identity I.
|
|
//
|
|
// Let `shape`'s layout be L. A layout is a permutation which maps a
|
|
// minor-to-major physical dimension ordering to a shape's logical dimension
|
|
// ordering. Therefore the inverse of a layout maps from logical to physical
|
|
// dims, and so the physical ordering of I is simply L'.I = L', where L' is
|
|
// the inverse of L.
|
|
//
|
|
// Let the argument `permutation` be P. This is a permutation over `shape`'s
|
|
// dimensions, so our return value will be a shape with dims P.I = P. Our
|
|
// goal is to construct a layout permutation L* for this shape. The physical
|
|
// dimension ordering of this returned shape must be the same as that of the
|
|
// original shape, namely L'.
|
|
//
|
|
// Our returned shape has dims P and layout L*, so its in-memory ordering is
|
|
// L*'.P. Setting this equal to L' and solving for L*, we get:
|
|
//
|
|
// L*'.P = L' =>
|
|
// L*' = L'P' =>
|
|
// L* = P.L
|
|
//
|
|
if (shape.has_layout()) {
|
|
CHECK(LayoutUtil::IsDenseArray(shape));
|
|
Layout* new_layout = new_shape.mutable_layout();
|
|
new_layout->set_format(DENSE);
|
|
new_layout->clear_minor_to_major();
|
|
for (auto index : ComposePermutations(
|
|
inv_permutation, AsInt64Slice(shape.layout().minor_to_major()))) {
|
|
new_layout->add_minor_to_major(index);
|
|
}
|
|
// The permutation accepted by TransposeIsBitcast is the inverse of the
|
|
// permutation here.
|
|
CHECK(TransposeIsBitcast(shape, new_shape, permutation))
|
|
<< "shape=" << HumanStringWithLayout(shape)
|
|
<< ", new_shape=" << HumanStringWithLayout(new_shape)
|
|
<< ", permutation={" << absl::StrJoin(permutation, ",") << "}";
|
|
}
|
|
return new_shape;
|
|
}
|
|
|
|
/* static */ std::tuple<bool, std::vector<int64>, std::vector<int64>>
|
|
ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
|
|
const Shape& shape_post) {
|
|
CHECK(shape_pre.IsArray());
|
|
CHECK(shape_post.IsArray());
|
|
|
|
auto nil = std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
|
|
|
|
std::vector<int64> deleted_indices;
|
|
std::vector<int64> inserted_indices;
|
|
// Returns false if any input/output dimension between
// prior_unmodified_dim_pair and unmodified_dim_pair has size >1. Otherwise,
// returns true and appends the degenerate input/output dimensions in the
// gap to deleted_indices/inserted_indices respectively.
|
|
auto check_modified_dims =
|
|
[&shape_pre, &shape_post, &deleted_indices, &inserted_indices](
|
|
std::pair<int64, int64> prior_unmodified_dim_pair,
|
|
std::pair<int64, int64> unmodified_dim_pair) {
|
|
for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1;
|
|
modified_input_dim < unmodified_dim_pair.first;
|
|
++modified_input_dim) {
|
|
if (shape_pre.dimensions(modified_input_dim) > 1) {
|
|
return false;
|
|
}
|
|
deleted_indices.push_back(modified_input_dim);
|
|
}
|
|
for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1;
|
|
modified_output_dim < unmodified_dim_pair.second;
|
|
++modified_output_dim) {
|
|
if (shape_post.dimensions(modified_output_dim) > 1) {
|
|
return false;
|
|
}
|
|
inserted_indices.push_back(modified_output_dim);
|
|
}
|
|
return true;
|
|
};
|
|
|
|
std::vector<std::pair<int64, int64>> unmodified_dims =
|
|
DimensionsUnmodifiedByReshape(shape_pre, shape_post);
|
|
// Returns nil if the reshape modifies any non-degenerate input/output
|
|
// dimension. DimensionsUnmodifiedByReshape gives us all unmodified
|
|
// dimensions, so we only need to check whether dimensions in the gaps (thus
|
|
// modified) have size >1.
|
|
for (size_t i = 0; i <= unmodified_dims.size(); ++i) {
|
|
// Check (modified) dimensions between unmodified_dims[i-1] and
|
|
// unmodified_dims[i].
|
|
auto prior_unmodified_dim_pair =
|
|
i > 0 ? unmodified_dims[i - 1] : std::pair<int64, int64>(-1, -1);
|
|
auto unmodified_dim_pair =
|
|
i < unmodified_dims.size()
|
|
? unmodified_dims[i]
|
|
: std::make_pair(shape_pre.rank(), shape_post.rank());
|
|
if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) {
|
|
return nil;
|
|
}
|
|
}
|
|
|
|
return std::make_tuple(true, deleted_indices, inserted_indices);
|
|
}
|
|
|
|
/* static */ std::vector<std::pair<int64, int64>>
|
|
ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
|
|
const Shape& output_shape) {
|
|
CHECK(input_shape.IsArray());
|
|
CHECK(output_shape.IsArray());
|
|
|
|
// Unmodified dimensions are merely common factors of rank 1.
|
|
auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()),
|
|
AsInt64Slice(output_shape.dimensions()));
|
|
for (size_t i = 0; i < common_factors.size() - 1;) {
|
|
if (1 != common_factors[i + 1].first - common_factors[i].first ||
|
|
1 != common_factors[i + 1].second - common_factors[i].second) {
|
|
common_factors.erase(common_factors.begin() + i);
|
|
} else {
|
|
++i;
|
|
}
|
|
}
|
|
// `CommonFactors(a, b).back() == (a.rank, b.rank)` so we must pop it.
|
|
common_factors.pop_back();
|
|
return std::vector<std::pair<int64, int64>>(common_factors.begin(),
|
|
common_factors.end());
|
|
}
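
// Illustrative example (assumed): for a reshape from [2,3,5] to [6,5] the only
// rank-1 common factor is the trailing 5, so the result is {{2, 1}}: input
// dimension 2 is unmodified and maps to output dimension 1.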
|
|
|
|
/* static */ absl::optional<std::vector<int64>>
|
|
ShapeUtil::ReshapeLeavesDimensionsUnmodified(
|
|
const Shape& from_shape, const Shape& to_shape,
|
|
absl::Span<const int64> input_dim_indices) {
|
|
CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));
|
|
|
|
std::vector<int64> output_dim_indices;
|
|
std::vector<std::pair<int64, int64>> unmodified_dims =
|
|
ShapeUtil::DimensionsUnmodifiedByReshape(from_shape, to_shape);
|
|
size_t i = 0; // index to unmodified_dims
|
|
for (int64 input_dim_index : input_dim_indices) {
|
|
// Search unmodified_dims for input_dim_index. We can search from the last
|
|
// matching position because input_dim_indices is guaranteed to be sorted.
|
|
while (i < unmodified_dims.size() &&
|
|
unmodified_dims[i].first < input_dim_index) {
|
|
++i;
|
|
}
|
|
if (i >= unmodified_dims.size() ||
|
|
unmodified_dims[i].first != input_dim_index) {
|
|
return absl::nullopt;
|
|
}
|
|
output_dim_indices.push_back(unmodified_dims[i].second);
|
|
}
|
|
return output_dim_indices;
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::TransposeIsBitcast(
|
|
const Shape& input_shape, const Shape& output_shape,
|
|
absl::Span<const int64> dimension_mapping) {
|
|
CHECK(LayoutUtil::HasLayout(input_shape) &&
|
|
LayoutUtil::HasLayout(output_shape));
|
|
|
|
if (!SameElementType(input_shape, output_shape)) {
|
|
return false;
|
|
}
|
|
|
|
// Check the reshape permutes the positions of each dimension in the
|
|
// minor-to-major order. positions[i]=k means dimension `i` is k-th minor.
|
|
// input_positions = apply(dimension_mapping, output_positions)
|
|
//
|
|
// Because the positions of each dimension are the inverse permutation of the
|
|
// minor-to-major order, the above check is equivalent to
|
|
// inverse(input_dimensions) =
|
|
// apply(dimension_mapping, inverse(output_dimensions))
|
|
// # `I` indicates identity permutation.
|
|
// apply(input_dimensions, I) =
|
|
// apply(dimension_mapping, apply(output_dimensions, I))
|
|
// apply(input_dimensions, I) =
|
|
// apply((dimension_mapping * output_dimensions), I)
|
|
// input_dimensions = dimension_mapping * output_dimensions
|
|
return absl::c_equal(
|
|
ComposePermutations(dimension_mapping,
|
|
AsInt64Slice(output_shape.layout().minor_to_major())),
|
|
input_shape.layout().minor_to_major());
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape,
|
|
const Shape& output_shape) {
|
|
CHECK(input_shape.IsArray());
|
|
CHECK(output_shape.IsArray());
|
|
CHECK(LayoutUtil::HasLayout(input_shape));
|
|
CHECK(LayoutUtil::HasLayout(output_shape));
|
|
|
|
if (!SameElementType(input_shape, output_shape)) {
|
|
return false;
|
|
}
|
|
|
|
CHECK_EQ(ElementsIn(input_shape), ElementsIn(output_shape))
|
|
<< "input_shape=" << input_shape.ShortDebugString()
|
|
<< ", output_shape=" << output_shape.ShortDebugString();
|
|
if (ElementsIn(input_shape) == 0) {
|
|
return true;
|
|
}
|
|
|
|
// TL;DR: The rest of the method checks that the reshape does not change the
|
|
// physical location of any unit input or output index. Unit indices have
|
|
// exactly one dimension that equals 1 and other dimensions 0. This condition
|
|
// is necessary for the reshape to be a bitcast, because a bitcast-equivalent
|
|
// reshape shouldn't change the physical location of any element. It is also a
|
|
// sufficient condition as is proved below (note: many details are omitted for
|
|
// space).
|
|
//
|
|
// Definitions:
|
|
//
|
|
// * Denote the input shape by IS and output shape by OS. IS[i] or OS[i] means
|
|
// the size of i-th least significant dimension of IS or OS (this is opposite
|
|
// to how we define the index of Shape::dimensions()).
|
|
//
|
|
// * Given an input or output index I, denote by p(I) I's physical linear
|
|
// index (or physical index for short) and l(I) I's logical linear index (or
|
|
// logical index for short).
|
|
//
|
|
// * Given a logical index k, denote by II(k) the input index whose linear
|
|
// index is k, and OI(k) the corresponding output index.
|
|
//
|
|
// * Denote by IT[i] the increment of physical index if i-th dimension of the
|
|
// input index is increased by 1. Similarly, OT[i] means the increment if i-th
|
|
// dimension of the output index is increased by 1. Note that IT[i] or OT[i]
|
|
// is a function of IS or OS and the layout, and not dependent on the specific
|
|
// input or output index.
|
|
//
|
|
// To prove the reshape from IS to OS is a bitcast, it is sufficient to prove
|
|
// that, for any linear index k, p(II(k))=p(OI(k)). We prove this by
|
|
// induction. We know p(II(0))=p(OI(0)) is trivially true, so what's left is
|
|
// to prove, with every increment on k, the above formula still holds.
|
|
//
|
|
// First, suppose reshaping from IS to OS is non-factorizable (we discuss
|
|
// factorizable reshapes later). A reshape from IS to OS is factorizable, if
|
|
// there exists (i,j) such that
|
|
//
|
|
// 0<=i<=|IS|
|
|
// 0<=j<=|OS|
|
|
// |IS|-i+|OS|-j > 0 (i.e., i,j mustn't both point to the end)
|
|
// product(IS[i], IS[i+1], ..., IS[|IS|-1])
|
|
// = product(OS[j], OS[j+1], ..., OS[|OS|-1])
|
|
//
|
|
// p(II(k))=p(OI(k)) is trivially true for k=0 because p(II(0)) and p(OI(0))
|
|
// are both 0. It's also trivially true for k=1, because II(1) and OI(1) are
|
|
// unit indices which are already tested. This also means IT[0]=OT[0]
|
|
// because p(II(1))=IT[0] and p(OI(1))=OT[0].
|
|
//
|
|
// Furthermore, p(II(k))=p(OI(k)) for k<min(IS[0],OS[0]), because each
|
|
// increment of k adds IT[0] to the input physical and OT[0] (same as IT[0])
|
|
// to the output physical.
|
|
//
|
|
// When k=min(IS[0],OS[0]), the first wrap happens. Without losing generality,
|
|
// suppose IS[0]<OS[0] and thus k=IS[0]. Similar proof applies to IS[0]>OS[0].
|
|
// Note that IS[0]!=OS[0] because the reshape is non-factorizable. From
|
|
// logical index k-1 to logical index k, dimension 1 of the input index
|
|
// is increased by 1 and dimension 0 is reset to 0 thus decreased by
|
|
// IS[0]-1. Therefore, the physical input index is increased by
|
|
//
|
|
// p(II(k)) - p(II(k-1)) = IT[1] - (IS[0]-1) * IT[0]
|
|
//
|
|
// Because IS[0]<OS[0], the only change to the output index is that its
|
|
// dimension 0 is increased by one. Therefore,
|
|
//
|
|
// p(OI(k)) - p(OI(k-1)) = OT[0] = IT[0]
|
|
//
|
|
// Because II(k) is a unit index -- (0,..,0,1,0), we already tested that
|
|
// p(II(k))=p(OI(k)). Therefore,
|
|
// IT[1] - (IS[0]-1) * IT[0] = IT[0]
|
|
// IT[1] = IS[0] * IT[0]
|
|
// In other words, input dimension 1 is immediately more major than input
|
|
// dimension 0. We can now conceptually collapse these two dimensions because
|
|
// an increment in the logical index affecting only these two dimensions maps
|
|
// to IT[0] in the physical index.
|
|
//
|
|
// By induction (omitted here), we can prove IT[i]=IS[i-1]*IT[i-1] and
|
|
// OT[i]=OS[i-1]*OT[i-1]. Therefore, both IS and OS are row-major and bitwise
|
|
// identical.
|
|
//
|
|
// A factorizable reshape can be factorized into a list of non-factorizable
|
|
// sub-reshapes, each of which can be handled similarly to the proof above.
|
|
// For example,
|
|
//
|
|
// [7x9x2x15] -> [63x6x5]
|
|
//
|
|
// can be factorized into
|
|
//
|
|
// [7x9] -> [63] and [2x15] -> [6x5].
|
|
//
|
|
// Suppose input index I=(x3,x2,x1,x0) and output index O=(y2,y1,y0) have the
|
|
// same logical linear index. According to the factorization, we know
|
|
// l(x3,x2,0,0)=l(y2,0,0) and l(0,0,x1,x0)=l(0,y1,y0). Using the proof for
|
|
// non-factorizable reshapes, we can prove p(0,0,x1,x0)=p(0,y1,y0). Using a
|
|
// similar proof, with the increment of the logical index set to
|
|
// IS[1]*IS[0]=OS[1]*OS[0]=30 instead of 1, we can prove
|
|
// p(x3,x2,0,0)=p(y2,0,0) too. Therefore,
|
|
//
|
|
// p(x3,x2,x1,x0) = p(x3,x2,0,0) + p(0,0,x1,x0)
|
|
// = p(y2,0,0) + p(0,0,y1,y0)
|
|
// = p(y2,y1,y0)
|
|
//
|
|
// check_input_unit_indices checks one way of the condition: each input unit
|
|
// index is mapped to an output index with the same physical location. This
|
|
// lambda will be called again with input_shape and output_shape reversed to
|
|
// check the other way.
|
|
auto check_input_unit_indices = [](const Shape& input_shape,
|
|
const Shape& output_shape) {
|
|
// input_shape_dim0_major/output_shape_dim0_major has the same "dimensions"
|
|
// as input_shape/output_shape and the dimension-0-major layout. These two
|
|
// shapes are used for conversion between logical linear indices and
|
|
// multi-dimensional indices.
|
|
Shape input_shape_dim0_major = MakeShapeWithDescendingLayout(
|
|
input_shape.element_type(), AsInt64Slice(input_shape.dimensions()));
|
|
Shape output_shape_dim0_major = MakeShapeWithDescendingLayout(
|
|
output_shape.element_type(), AsInt64Slice(output_shape.dimensions()));
|
|
|
|
for (int64 input_dim = 0; input_dim < input_shape.rank(); ++input_dim) {
|
|
if (input_shape.dimensions(input_dim) <= 1) {
|
|
continue;
|
|
}
|
|
|
|
std::vector<int64> input_unit_index(input_shape.rank(), 0);
|
|
input_unit_index[input_dim] = 1;
|
|
int64 logical_linear_index =
|
|
IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major,
|
|
input_unit_index);
|
|
// output_index has the same logical linear index as input_unit_index.
|
|
std::vector<int64> output_index =
|
|
IndexUtil::LinearIndexToMultidimensionalIndex(output_shape_dim0_major,
|
|
logical_linear_index);
|
|
// Check input_unit_index and output_index have the same physical linear
|
|
// index.
|
|
if (IndexUtil::MultidimensionalIndexToLinearIndex(input_shape,
|
|
input_unit_index) !=
|
|
IndexUtil::MultidimensionalIndexToLinearIndex(output_shape,
|
|
output_index)) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
};
|
|
return check_input_unit_indices(input_shape, output_shape) &&
|
|
check_input_unit_indices(output_shape, input_shape);
|
|
}
|
|
|
|
/* static */ absl::optional<Shape> ShapeUtil::AlignLayouts(
|
|
const Shape& input_shape, const Shape& output_shape) {
|
|
CHECK(input_shape.IsArray());
|
|
CHECK(output_shape.IsArray());
|
|
// Removing trivial dimensions from the shape simplifies the alignment
|
|
// algorithm since ones can go in any position.
|
|
if (HasDegenerateDimensions(input_shape) ||
|
|
HasDegenerateDimensions(output_shape)) {
|
|
auto simple_output_shape =
|
|
AlignLayouts(DropDegenerateDimensions(input_shape),
|
|
DropDegenerateDimensions(output_shape));
|
|
if (!simple_output_shape) {
|
|
return absl::nullopt;
|
|
}
|
|
|
|
std::vector<int64> layout =
|
|
SpanToVector(simple_output_shape->layout().minor_to_major());
|
|
// For each one sized dimension in the output, increment the dimension
|
|
// numbers in layout that are more minor than the one.
|
|
absl::InlinedVector<int64, 8> dim_map;
|
|
dim_map.reserve(simple_output_shape->rank());
|
|
for (int64 i = 0; i < output_shape.rank(); ++i) {
|
|
if (output_shape.dimensions(i) != 1) {
|
|
dim_map.push_back(i);
|
|
}
|
|
}
|
|
for (int64& d : layout) {
|
|
d = dim_map[d];
|
|
}
|
|
|
|
// Add the ones in descending order to the layout. Descending layouts tend
|
|
// to reduce the number of copies inserted in layout assignment.
|
|
for (int64 i = output_shape.rank() - 1; i >= 0; --i) {
|
|
if (output_shape.dimensions(i) == 1) {
|
|
layout.push_back(i);
|
|
}
|
|
}
|
|
Shape output_shape_with_layout = output_shape;
|
|
*output_shape_with_layout.mutable_layout() = Layout{layout};
|
|
return output_shape_with_layout;
|
|
}
|
|
|
|
int64 input_rank = input_shape.rank();
|
|
int64 output_rank = output_shape.rank();
|
|
|
|
// First, calculate an alignment of the dimensions. A consecutive sequence of
|
|
// input dimensions and output dimensions belong to the same alignment part if
|
|
// the products of their dimension bounds are the same. In the easiest case,
|
|
// an alignment part consists of one input dimension and one output dimension
|
|
// which both have the same dimension bound. An alignment part specifies which
|
|
// dimensions need to be kept together in a physical layout if we want a
|
|
// reshape to be a bitcast. The order of the alignment parts is defined by the
|
|
// physical layout of the input shape, so when we construct the layout for the
|
|
// output shape we just process the alignment parts in this order, and then
|
|
// layout the dimensions belonging to each part in descending (major to minor)
|
|
// order.
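//
// Illustrative example (assumed, not from the original comment): reshaping
// [4,5,6] to [20,6] produces the alignment parts {(0,0), (2,1), (3,2)}:
// input dimensions {0,1} (product 20) pair with output dimension {0}, and
// input dimension {2} pairs with output dimension {1}.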
|
|
|
|
// Stores the input and output dimension numbers where each alignment part
|
|
// starts.
|
|
std::vector<std::pair<int64, int64>> alignment;
|
|
alignment.push_back({0, 0});
|
|
|
|
// Stores a mapping from the input dimension to the alignment part it belongs
|
|
// to.
|
|
std::vector<int64> dimension_to_alignment_index(input_rank);
|
|
int64 input_dimension_product = 1, output_dimension_product = 1;
|
|
for (int64 i = 0, j = 0; i < input_rank || j < output_rank;) {
|
|
// Check if we have reached the end of an alignment part.
|
|
if (input_dimension_product == output_dimension_product &&
|
|
input_dimension_product > 1) {
|
|
alignment.push_back({i, j});
|
|
input_dimension_product = output_dimension_product = 1;
|
|
}
|
|
if (input_dimension_product < output_dimension_product ||
|
|
j == output_rank) {
|
|
if (i == input_rank) {
|
|
return absl::nullopt;
|
|
}
|
|
dimension_to_alignment_index[i] = alignment.size() - 1;
|
|
input_dimension_product *= input_shape.dimensions(i);
|
|
++i;
|
|
} else {
|
|
output_dimension_product *= output_shape.dimensions(j);
|
|
++j;
|
|
}
|
|
}
|
|
if (input_dimension_product != output_dimension_product) {
|
|
return absl::nullopt;
|
|
}
|
|
|
|
// We also need to store an end element so that we know where the last
|
|
// alignment part ends.
|
|
alignment.push_back({input_rank, output_rank});
|
|
// Now check if the physical layout can potentially be aligned to the output
|
|
// shape by changing the physical layout of the output shape. We need to check
|
|
// that all dimension numbers that belong to the same alignment part appear
|
|
// consecutively, and are in descending order. However we can ignore any
|
|
// trivial dimension bounds of 1, because they can be placed anywhere.
|
|
auto input_dimension_numbers = input_shape.layout().minor_to_major();
|
|
std::vector<int64> output_layout;
|
|
output_layout.reserve(output_rank);
|
|
for (int64 i = 0; i < input_rank;) {
|
|
int64 current_dimension_number = input_dimension_numbers[i];
|
|
|
|
// Trivial dimensions are stripped.
|
|
CHECK_NE(input_shape.dimensions(current_dimension_number), 1);
|
|
const int64 current_alignment_index =
|
|
dimension_to_alignment_index[current_dimension_number];
|
|
// Because of the special end element that we added, we can be sure that
|
|
// 'current_alignment_index' is < alignment.size() - 1.
|
|
CHECK_LT(current_alignment_index, alignment.size() - 1);
|
|
|
|
// Check that the dimension numbers covered by the current alignment part are
// in descending order and all belong to that part (degenerate dimensions of
// bound 1 were already stripped above).
|
|
for (int64 j = 0; j < alignment[current_alignment_index + 1].first -
|
|
alignment[current_alignment_index].first;
|
|
++i, ++j) {
|
|
if (i == input_rank) {
|
|
return absl::nullopt;
|
|
}
|
|
// If the current dimension number belongs to a different alignment part,
|
|
// or the dimension numbers are not in descending order, we can return
|
|
// early.
|
|
if (dimension_to_alignment_index[input_dimension_numbers[i]] !=
|
|
current_alignment_index ||
|
|
input_dimension_numbers[i] > current_dimension_number) {
|
|
return absl::nullopt;
|
|
}
|
|
current_dimension_number = input_dimension_numbers[i];
|
|
}
|
|
// The output dimension numbers that belong to the current alignment part
|
|
// need to appear in the same descending order as in the input.
|
|
for (int64 j = alignment[current_alignment_index + 1].second - 1;
|
|
j >= alignment[current_alignment_index].second; --j) {
|
|
output_layout.push_back(j);
|
|
}
|
|
}
|
|
CHECK_EQ(output_layout.size(), output_rank);
|
|
Shape output_shape_with_layout = MakeShapeWithLayout(
|
|
output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
|
|
output_layout);
|
|
CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout))
|
|
<< "reshape is not a bitcast for input_shape: "
|
|
<< ShapeUtil::HumanStringWithLayout(input_shape)
|
|
<< " and output_shape_with_layout: "
|
|
<< ShapeUtil::HumanStringWithLayout(output_shape_with_layout);
|
|
return output_shape_with_layout;
|
|
}
|
|
|
|
/* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete,
|
|
Shape shape) {
|
|
CHECK(shape.IsArray());
|
|
shape.DeleteDimension(dim_to_delete);
|
|
return shape;
|
|
}
|
|
|
|
/* static */ bool ShapeUtil::DynamicArrayShapeIsCompatible(
|
|
const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) {
|
|
if (dynamic_shape.rank() != bounded_shape.rank()) {
|
|
return false;
|
|
}
|
|
for (int64 i = 0; i < dynamic_shape.rank(); ++i) {
|
|
if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
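
// Illustrative example (assumed): a dynamic array whose dimension bounds are
// {3, 4} is compatible with a bounded shape of {3, 10}, but not with {2, 10},
// since every bound must fit within the corresponding bounded dimension.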
|
|
|
|
/* static */ bool ShapeUtil::DynamicShapeIsCompatible(
|
|
const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) {
|
|
bool compatible = true;
|
|
xla::ShapeUtil::ForEachSubshape(dynamic_shape, [&](const Shape& sub_shape,
|
|
const ShapeIndex& index) {
|
|
if (compatible) {
|
|
auto subshape_result = TryGetSubshape(bounded_shape, index);
|
|
if (subshape_result.ok()) {
|
|
const Shape* bounded_sub_shape = subshape_result.ConsumeValueOrDie();
|
|
if (sub_shape.IsTuple()) {
|
|
if (!bounded_sub_shape->IsTuple()) {
|
|
compatible = false;
|
|
}
|
|
} else {
|
|
if (bounded_sub_shape->IsTuple()) {
|
|
compatible = false;
|
|
} else if (!sub_shape.is_static() &&
|
|
!DynamicArrayShapeIsCompatible(sub_shape,
|
|
*bounded_sub_shape)) {
|
|
compatible = false;
|
|
}
|
|
}
|
|
} else {
|
|
compatible = false;
|
|
}
|
|
}
|
|
});
|
|
return compatible;
|
|
}
|
|
|
|
/* static */ Shape ShapeUtil::FilterDimensions(
|
|
const std::function<bool(int64)>& p, Shape shape) {
|
|
CHECK(shape.IsArray());
|
|
std::vector<int64> dims_to_delete;
|
|
for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) {
|
|
if (!p(i)) {
|
|
dims_to_delete.push_back(i);
|
|
}
|
|
}
|
|
for (int64 dim : dims_to_delete) {
|
|
shape = DeleteDimension(dim, shape);
|
|
}
|
|
return shape;
|
|
}
|
|
|
|
/*static*/ size_t ShapeUtil::Hash(const Shape& shape) {
|
|
using tensorflow::hash;
|
|
using tensorflow::Hash64Combine;
|
|
|
|
size_t hash_value = hash<PrimitiveType>()(shape.element_type());
|
|
|
|
if (shape.tuple_shapes().empty()) {
|
|
for (int i = 0; i < shape.dimensions_size(); ++i) {
|
|
hash_value =
|
|
Hash64Combine(hash_value, hash<int64>()(shape.dimensions(i)));
|
|
hash_value = Hash64Combine(hash_value,
|
|
hash<bool>()(shape.is_dynamic_dimension(i)));
|
|
}
|
|
|
|
hash_value = Hash64Combine(hash_value, LayoutUtil::Hash(shape.layout()));
|
|
} else {
|
|
hash_value = 0;
|
|
for (const Shape& subshape : shape.tuple_shapes()) {
|
|
hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(subshape));
|
|
}
|
|
}
|
|
|
|
return hash_value;
|
|
}
|
|
|
|
// Returns the indices of the first elements of all consecutive subarrays of the
|
|
// given array. For example:
|
|
// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
|
|
static std::vector<size_t> ConsecutiveSegments(absl::Span<const int64> xs) {
|
|
std::vector<size_t> is = {0};
|
|
for (size_t i = 1; i < xs.size(); ++i) {
|
|
if (1 != xs[i] - xs[i - 1]) {
|
|
is.push_back(i);
|
|
}
|
|
}
|
|
return is;
|
|
}
|
|
|
|
// Merges the sequences of dimensions of the given shape which start at the
|
|
// given indices `segs`.
|
|
static Shape MergeDimensions(absl::Span<const size_t> segs,
|
|
const Shape& shape) {
|
|
std::vector<int64> dimensions;
|
|
for (size_t i = 1; i <= segs.size(); ++i) {
|
|
dimensions.push_back(std::accumulate(
|
|
shape.dimensions().begin() + segs[i - 1],
|
|
shape.dimensions().begin() +
|
|
(segs.size() == i ? shape.dimensions().size() : segs[i]),
|
|
1, std::multiplies<int64>()));
|
|
}
|
|
return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
|
|
dimensions);
|
|
}
|
|
|
|
/*static*/ absl::optional<std::vector<int64>> ShapeUtil::FindTranspose021(
|
|
const Shape& a, const Shape& b) {
|
|
if (!CompatibleIgnoringElementType(a, b)) {
|
|
return absl::nullopt;
|
|
}
|
|
|
|
std::vector<int64> permutation(a.dimensions().size());
|
|
absl::Span<const int64> minor_to_major_a = LayoutUtil::MinorToMajor(a);
|
|
std::vector<int64> major_to_minor_a(minor_to_major_a.rbegin(),
|
|
minor_to_major_a.rend());
|
|
absl::Span<const int64> minor_to_major_b = LayoutUtil::MinorToMajor(b);
|
|
std::vector<int64> major_to_minor_b(minor_to_major_b.rbegin(),
|
|
minor_to_major_b.rend());
|
|
for (size_t i = 0; i < permutation.size(); ++i) {
|
|
permutation[i] = PositionInContainer(major_to_minor_b, major_to_minor_a[i]);
|
|
}
|
|
|
|
std::vector<size_t> segments = ConsecutiveSegments(permutation);
|
|
if ((3 == segments.size() && 0 == permutation[0]) || 2 == segments.size()) {
|
|
Shape descending_layout_shape =
|
|
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
|
|
Shape normalized_shape = MergeDimensions(segments, descending_layout_shape);
|
|
absl::Span<const int64> normalized_dims =
|
|
AsInt64Slice(normalized_shape.dimensions());
|
|
std::vector<int64> dims_021;
|
|
if (2 == segments.size()) {
|
|
// The logical component-0 is of size one.
|
|
dims_021 = {1, normalized_dims[1], normalized_dims[0]};
|
|
} else {
|
|
dims_021 = {normalized_dims[0], normalized_dims[2], normalized_dims[1]};
|
|
}
|
|
|
|
return dims_021;
|
|
}
|
|
|
|
return absl::nullopt;
|
|
}
|
|
|
|
Shape ShapeUtil::DeviceShapeToHostShape(Shape s) {
|
|
ForEachMutableSubshape(&s, [](Shape* subshape, const ShapeIndex& index) {
|
|
if (subshape->IsArray()) {
|
|
subshape->mutable_layout()->clear_tiles();
|
|
subshape->mutable_layout()->set_memory_space(Layout::kDefaultMemorySpace);
|
|
}
|
|
});
|
|
return s;
|
|
}
|
|
|
|
/*static*/ bool ShapeUtil::ElementCanUpcast(const Shape& from,
|
|
const Shape& to) {
|
|
return ElementIsFloating(from) == ElementIsFloating(to) &&
|
|
ElementIsSigned(from) == ElementIsSigned(to) &&
|
|
HigherPrecisionElementType(from, to) == to.element_type();
|
|
}
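
// Illustrative example (assumed): s8 can upcast to s32 (both are signed
// integral types and s32 is the higher-precision of the two), but s32 cannot
// upcast to u32 (signedness differs) or to f32 (floating-ness differs).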
|
|
|
|
} // namespace xla
|