Replace calls to ShapeUtil::Rank with Shape::rank.

No functional change. ShapeUtil::Rank is marked as deprecated. A later CL will remove it.

PiperOrigin-RevId: 226412117
This commit is contained in:
Mark Heffernan 2018-12-20 16:24:50 -08:00 committed by TensorFlower Gardener
parent e7bac9435d
commit 1ed59e52b1
65 changed files with 335 additions and 347 deletions

View File

@ -129,7 +129,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type));
std::vector<string> dim_vars;
string dim_sizes, indices;
if (xla::ShapeUtil::Rank(shape) == 0 ||
if (shape.rank() == 0 ||
(shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) {
dim_sizes = "[1]";
indices = "[0]";

View File

@ -39,7 +39,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
OP_REQUIRES(
ctx,
xla::ShapeUtil::Rank(crops.shape()) == 2 &&
crops.shape().rank() == 2 &&
block_rank == xla::ShapeUtil::GetDimension(crops.shape(), 0) &&
2 == xla::ShapeUtil::GetDimension(crops.shape(), 1),
errors::InvalidArgument("crops should have shape [", block_rank,

View File

@ -110,8 +110,8 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
auto shape_status = b.GetShape(arg);
OP_REQUIRES_OK(ctx, shape_status.status());
xla::Shape arg_shape = shape_status.ConsumeValueOrDie();
*arg_shape.mutable_layout() = xla::LayoutUtil::MakeDescendingLayout(
xla::ShapeUtil::Rank(arg_shape));
*arg_shape.mutable_layout() =
xla::LayoutUtil::MakeDescendingLayout(arg_shape.rank());
arg_shapes.push_back(std::move(arg_shape));
}

View File

@ -38,8 +38,7 @@ class MirrorPadOp : public XlaOpKernel {
// - [1, 2, 3, 3, 2] in symmetric mode.
int64 excluded_edges = mode == MirrorPadMode::REFLECT ? 1 : 0;
xla::XlaOp accum = t;
for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0;
--dimno) {
for (int64 dimno = original_shape.rank() - 1; dimno >= 0; --dimno) {
auto t_rev = xla::Rev(accum, {dimno});
int64 lhs_padding = pad_literal.Get<int64>({dimno, 0});
int64 rhs_padding = pad_literal.Get<int64>({dimno, 1});

View File

@ -39,7 +39,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
OP_REQUIRES(
ctx,
xla::ShapeUtil::Rank(paddings.shape()) == 2 &&
paddings.shape().rank() == 2 &&
block_rank == xla::ShapeUtil::GetDimension(paddings.shape(), 0) &&
2 == xla::ShapeUtil::GetDimension(paddings.shape(), 1),
errors::InvalidArgument("paddings should have shape [", block_rank,

View File

@ -48,7 +48,7 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
if (indices_are_vectors) {
TF_RET_CHECK(!indices_dims.empty());
num_index_dims = indices_dims.back();
if (num_index_dims > xla::ShapeUtil::Rank(buffer_shape)) {
if (num_index_dims > buffer_shape.rank()) {
return errors::InvalidArgument(
"The size of the minor dimension of the indices (shape: ",
xla::ShapeUtil::HumanString(indices_shape),
@ -140,8 +140,8 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
? indices_shape.dimensions_size() - 1
: indices_shape.dimensions_size());
int64 updates_rank = xla::ShapeUtil::Rank(updates_shape);
int64 buffer_rank = xla::ShapeUtil::Rank(buffer_shape);
int64 updates_rank = updates_shape.rank();
int64 buffer_rank = buffer_shape.rank();
int64 num_window_dims_in_updates = buffer_rank - num_index_dims;
// If the rank of `updates` is 0 and does not match the expected rank of
@ -156,7 +156,7 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
if (updates_rank == 0 && expected_updates_rank != 0) {
new_updates = xla::Broadcast(updates, expected_updates_dims);
TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates));
updates_rank = xla::ShapeUtil::Rank(updates_shape);
updates_rank = updates_shape.rank();
}
if (updates_rank > 0) {

View File

@ -39,7 +39,7 @@ Status PopulateInfeedLayoutVector(const xla::Shape& shape,
layouts->push_back(dim);
}
} else {
layouts->insert(layouts->end(), xla::ShapeUtil::Rank(shape), -1);
layouts->insert(layouts->end(), shape.rank(), -1);
}
return Status::OK();
}
@ -55,7 +55,7 @@ Status XLAShapeToTensorShape(const xla::Shape& shape,
" cannot be converted to a TensorShape");
}
*tensor_shape = TensorShape();
for (int i = 0; i < xla::ShapeUtil::Rank(shape); ++i) {
for (int i = 0; i < shape.rank(); ++i) {
tensor_shape->AddDim(shape.dimensions(i));
}
return Status::OK();

View File

@ -55,7 +55,7 @@ xla::XlaOp ArgMinMax(xla::XlaOp input, xla::PrimitiveType output_type, int axis,
xla::XlaOp input_max = xla::Reduce(input, init_value, reducer,
/*dimensions_to_reduce=*/{axis});
std::vector<int64> broadcast_dims(xla::ShapeUtil::Rank(input_shape) - 1);
std::vector<int64> broadcast_dims(input_shape.rank() - 1);
std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
// Compute a mask that has 1s for elements equal to the maximum.

View File

@ -178,7 +178,7 @@ Status XlaOpKernelContext::ConstantInputReshaped(
// Converts an int32 or int64 scalar literal to an int64.
static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal,
int64* out) {
if (xla::ShapeUtil::Rank(literal.shape()) != 0) {
if (literal.shape().rank() != 0) {
return errors::InvalidArgument("value is not a scalar");
}
if (literal.shape().element_type() == xla::S32) {
@ -194,7 +194,7 @@ static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal,
// Converts an float32 or float64 scalar literal to a float64.
static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal,
double* out) {
if (xla::ShapeUtil::Rank(literal.shape()) != 0) {
if (literal.shape().rank() != 0) {
return errors::InvalidArgument("value is not a scalar");
}
if (literal.shape().element_type() == xla::F32) {
@ -228,7 +228,7 @@ Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) {
// Converts an int32 or int64 1D literal to an int64 vector.
static Status LiteralToInt64Vector(const xla::LiteralSlice& literal,
std::vector<int64>* out) {
if (xla::ShapeUtil::Rank(literal.shape()) != 1) {
if (literal.shape().rank() != 1) {
return errors::InvalidArgument("value is not 1D");
}
int64 size = xla::ShapeUtil::ElementsIn(literal.shape());

View File

@ -117,7 +117,7 @@ XlaOp Any(XlaOp predicates) {
XlaComputation logical_or = CreateScalarOrComputation(PRED, builder);
TF_ASSIGN_OR_RETURN(const Shape& predicates_shape,
builder->GetShape(predicates));
std::vector<int64> all_dimensions(ShapeUtil::Rank(predicates_shape));
std::vector<int64> all_dimensions(predicates_shape.rank());
std::iota(all_dimensions.begin(), all_dimensions.end(), 0);
return Reduce(predicates, f, logical_or, all_dimensions);
});

View File

@ -54,7 +54,7 @@ XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) {
XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
const int n_dims = ShapeUtil::Rank(a_shape);
const int n_dims = a_shape.rank();
const int64 n = ShapeUtil::GetDimension(a_shape, -1);
auto major_dims = AsInt64Slice(a_shape.dimensions())
.subspan(
@ -144,7 +144,7 @@ XlaOp Cholesky(XlaOp a, int64 block_size,
XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
const int ndims = ShapeUtil::Rank(a_shape);
const int ndims = a_shape.rank();
if (ndims < 2) {
return InvalidArgument(
"Argument to Cholesky must have rank >= 2; shape was %s",

View File

@ -41,7 +41,7 @@ XlaOp GetMatrixDiagonal(XlaOp x) {
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
const int64 n_dims = ShapeUtil::Rank(shape);
const int64 n_dims = shape.rank();
TF_RET_CHECK(n_dims >= 2);
const int64 m = shape.dimensions(n_dims - 2);
const int64 n = shape.dimensions(n_dims - 1);
@ -68,7 +68,7 @@ XlaOp TriangleMask(XlaOp x, int diagonal) {
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
const int64 n_dims = ShapeUtil::Rank(shape);
const int64 n_dims = shape.rank();
TF_RET_CHECK(n_dims >= 2);
const int64 m = shape.dimensions(n_dims - 2);
const int64 n = shape.dimensions(n_dims - 1);
@ -99,12 +99,12 @@ XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) {
// Check that both tensors have the same number of dimensions. There must be
// at least two (the batch dimensions can be empty).
if (ShapeUtil::Rank(x_shape) != ShapeUtil::Rank(y_shape)) {
if (x_shape.rank() != y_shape.rank()) {
return InvalidArgument(
"Arguments to BatchDot have different ranks: %s vs. %s",
ShapeUtil::HumanString(x_shape), ShapeUtil::HumanString(y_shape));
}
const int ndims = ShapeUtil::Rank(x_shape);
const int ndims = x_shape.rank();
if (ndims < 2) {
return InvalidArgument(
"Arguments to BatchDot must have rank >= 2: got %d", ndims);
@ -169,7 +169,7 @@ XlaOp TransposeInMinorDims(XlaOp x) {
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
const int64 n_dims = ShapeUtil::Rank(shape);
const int64 n_dims = shape.rank();
TF_RET_CHECK(n_dims >= 2);
std::vector<int64> permutation(n_dims);
std::iota(permutation.begin(), permutation.end(), 0);

View File

@ -154,7 +154,7 @@ struct QRBlockResult {
StatusOr<QRBlockResult> QRBlock(XlaOp a, PrecisionConfig::Precision precision) {
XlaBuilder* builder = a.builder();
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
const int num_dims = ShapeUtil::Rank(a_shape);
const int num_dims = a_shape.rank();
if (num_dims < 2) {
return InvalidArgument("Argument to QR must have rank >= 2; got shape %s",
a_shape.ToString());
@ -325,7 +325,7 @@ StatusOr<QRDecompositionResult> QRDecomposition(
PrecisionConfig::Precision precision) {
XlaBuilder* builder = a.builder();
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
const int num_dims = ShapeUtil::Rank(a_shape);
const int num_dims = a_shape.rank();
if (num_dims < 2) {
return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s",
a_shape.ToString());

View File

@ -26,7 +26,7 @@ XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64> start,
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
const int64 n_dims = ShapeUtil::Rank(shape);
const int64 n_dims = shape.rank();
TF_RET_CHECK(n_minor_dims <= n_dims);
auto major_dims = AsInt64Slice(shape.dimensions())
.subspan(
@ -55,7 +55,7 @@ XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64> start) {
std::vector<int32> start_as_int32(start.begin(), start.end());
auto start_constant = ConstantR1<int32>(builder, start_as_int32);
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
const int64 n_dims = ShapeUtil::Rank(shape);
const int64 n_dims = shape.rank();
TF_ASSIGN_OR_RETURN(Shape start_constant_shape,
builder->GetShape(start_constant));
const int64 start_length =
@ -70,7 +70,7 @@ XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update,
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
const int64 n_dims = ShapeUtil::Rank(shape);
const int64 n_dims = shape.rank();
const int64 n_minor_dims = start.size();
TF_RET_CHECK(n_minor_dims <= n_dims);
std::vector<int64> padded_start(n_dims, 0);
@ -94,7 +94,7 @@ XlaOp PrependZerosInMajorDims(XlaOp x, absl::Span<const XlaOp> starts) {
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
const int64 n_dims = ShapeUtil::Rank(shape);
const int64 n_dims = shape.rank();
auto zero = Reshape(ConstantR0<int32>(builder, 0), {1});
std::vector<XlaOp> padded_starts(n_dims, zero);
for (int i = 0; i < starts.size(); ++i) {
@ -111,7 +111,7 @@ XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span<const XlaOp> starts,
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
const int64 n_dims = ShapeUtil::Rank(shape);
const int64 n_dims = shape.rank();
int64 n_minor_dims = starts.size();
TF_RET_CHECK(n_minor_dims == sizes.size());
TF_RET_CHECK(n_minor_dims <= n_dims);

View File

@ -38,7 +38,7 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) {
XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(a));
int ndims = ShapeUtil::Rank(shape);
int ndims = shape.rank();
int64 n = ShapeUtil::GetDimension(shape, -1);
int64 num_blocks = n / block_size;
@ -262,7 +262,7 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
int64 block_size = ShapeUtil::GetDimension(blocks_shape, -1);
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
int64 ndims = ShapeUtil::Rank(a_shape);
int64 ndims = a_shape.rank();
int64 n = ShapeUtil::GetDimension(a_shape, -1);
int64 num_blocks = n / block_size + (n % block_size != 0);
int64 m_dim = (left_side) ? -1 : -2;
@ -356,13 +356,13 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
if (ShapeUtil::Rank(a_shape) != ShapeUtil::Rank(b_shape)) {
if (a_shape.rank() != b_shape.rank()) {
return InvalidArgument(
"Arguments to TriangularSolve have shapes with different ranks: "
"%s vs. %s",
ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
}
const int64 ndims = ShapeUtil::Rank(a_shape);
const int64 ndims = a_shape.rank();
if (ndims < 2) {
return InvalidArgument(
"Arguments to TriangularSolve was rank %d but must have rank >= 2.",

View File

@ -50,7 +50,7 @@ OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles) {
OpSharding result;
result.set_type(OpSharding::Type::OpSharding_Type_OTHER);
CHECK_EQ(ShapeUtil::Rank(tile_shape), 1);
CHECK_EQ(tile_shape.rank(), 1);
std::vector<int64> dimensions(1, num_tiles);
*result.mutable_tile_shape() = tile_shape.ToProto();
auto& tile_dimension =

View File

@ -343,7 +343,7 @@ StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
CHECK(ShapeUtil::IsScalar(operand_shape) ||
ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape));
operand_shape.rank() == output_shape.rank());
Shape broadcast_shape =
ShapeUtil::ChangeElementType(output_shape, operand_shape.element_type());
@ -355,7 +355,7 @@ StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
// Do explicit broadcast for degenerate broadcast.
std::vector<int64> broadcast_dimensions;
std::vector<int64> reshaped_dimensions;
for (int i = 0; i < ShapeUtil::Rank(operand_shape); i++) {
for (int i = 0; i < operand_shape.rank(); i++) {
if (operand_shape.dimensions(i) == output_shape.dimensions(i)) {
broadcast_dimensions.push_back(i);
reshaped_dimensions.push_back(operand_shape.dimensions(i));
@ -398,8 +398,8 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
binop, lhs_shape, rhs_shape, broadcast_dimensions));
*instr.mutable_shape() = shape.ToProto();
const int64 lhs_rank = ShapeUtil::Rank(lhs_shape);
const int64 rhs_rank = ShapeUtil::Rank(rhs_shape);
const int64 lhs_rank = lhs_shape.rank();
const int64 rhs_rank = rhs_shape.rank();
XlaOp updated_lhs = lhs;
XlaOp updated_rhs = rhs;
@ -413,8 +413,7 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
for (int64 size : shape.dimensions()) {
to_size.push_back(size);
}
for (int64 from_dim = 0; from_dim < ShapeUtil::Rank(from_shape);
from_dim++) {
for (int64 from_dim = 0; from_dim < from_shape.rank(); from_dim++) {
int64 to_dim = broadcast_dimensions[from_dim];
to_size[to_dim] = from_shape.dimensions(from_dim);
}
@ -563,10 +562,10 @@ XlaOp XlaBuilder::Broadcast(const XlaOp& operand,
// output, so to append dimensions on the left the instruction's dimensions
// should just be the n highest dimension numbers of the output shape where
// n is the number of input dimensions.
const int64 operand_rank = ShapeUtil::Rank(operand_shape);
const int64 operand_rank = operand_shape.rank();
std::vector<int64> dimensions(operand_rank);
for (int i = 0; i < operand_rank; ++i) {
dimensions[i] = i + ShapeUtil::Rank(shape) - operand_rank;
dimensions[i] = i + shape.rank() - operand_rank;
}
return InDimBroadcast(shape, operand, dimensions);
});
@ -639,10 +638,10 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index,
int64 limit_index, int64 stride, int64 dimno) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
std::vector<int64> starts(ShapeUtil::Rank(shape), 0);
std::vector<int64> starts(shape.rank(), 0);
std::vector<int64> limits(shape.dimensions().begin(),
shape.dimensions().end());
std::vector<int64> strides(ShapeUtil::Rank(shape), 1);
std::vector<int64> strides(shape.rank(), 1);
starts[dimno] = start_index;
limits[dimno] = limit_index;
strides[dimno] = stride;
@ -780,7 +779,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand,
VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ",");
std::vector<int64> new_sizes;
for (int i = 0; i < ShapeUtil::Rank(original_shape); ++i) {
for (int i = 0; i < original_shape.rank(); ++i) {
if (i <= dimensions.front() || i > dimensions.back()) {
new_sizes.push_back(original_shape.dimensions(i));
} else {
@ -915,13 +914,13 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
Status XlaBuilder::VerifyConvolution(
const Shape& lhs_shape, const Shape& rhs_shape,
const ConvolutionDimensionNumbers& dimension_numbers) const {
if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) {
if (lhs_shape.rank() != rhs_shape.rank()) {
return InvalidArgument(
"Convolution arguments must have same number of "
"dimensions. Got: %s and %s",
ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
}
int num_dims = ShapeUtil::Rank(lhs_shape);
int num_dims = lhs_shape.rank();
if (num_dims < 2) {
return InvalidArgument(
"Convolution expects argument arrays with >= 3 dimensions. "
@ -1582,7 +1581,7 @@ XlaOp XlaBuilder::Sort(const XlaOp& keys, absl::Span<const XlaOp> values,
*instr.mutable_shape() = shape.ToProto();
if (dimension == -1) {
TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys));
dimension = ShapeUtil::Rank(keys_shape) - 1;
dimension = keys_shape.rank() - 1;
}
instr.add_dimensions(dimension);
std::vector<XlaOp> operands{keys};
@ -1652,12 +1651,12 @@ XlaOp XlaBuilder::Map(absl::Span<const XlaOp> operands,
*instr.mutable_shape() = shape.ToProto();
Shape output_shape(instr.shape());
const int64 output_rank = ShapeUtil::Rank(output_shape);
const int64 output_rank = output_shape.rank();
AddCalledComputation(computation, &instr);
std::vector<XlaOp> new_operands(operands.begin(), operands.end());
for (XlaOp& new_operand : new_operands) {
TF_ASSIGN_OR_RETURN(Shape shape, GetShape(new_operand));
const int64 rank = ShapeUtil::Rank(shape);
const int64 rank = shape.rank();
if (rank != output_rank) {
TF_ASSIGN_OR_RETURN(new_operand,
InDimBroadcast(output_shape, new_operand, {}));
@ -1866,7 +1865,7 @@ XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
std::vector<int64> all_dimnos(ShapeUtil::Rank(operand_shape));
std::vector<int64> all_dimnos(operand_shape.rank());
std::iota(all_dimnos.begin(), all_dimnos.end(), 0);
return Reduce(operand, init_value, computation, all_dimnos);
});

View File

@ -141,7 +141,7 @@ namespace xla {
/* static */ bool IndexUtil::IndexInBounds(const Shape& shape,
absl::Span<const int64> index) {
int64 rank = ShapeUtil::Rank(shape);
int64 rank = shape.rank();
if (rank != index.size()) {
return false;
}

View File

@ -211,19 +211,19 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
}
if (layout.format() == DENSE) {
if (layout.minor_to_major_size() != ShapeUtil::Rank(shape)) {
if (layout.minor_to_major_size() != shape.rank()) {
return InvalidArgument(
"layout minor_to_major field contains %d elements, "
"but shape is rank %d: {%s}; shape: %s",
layout.minor_to_major_size(), ShapeUtil::Rank(shape),
layout.minor_to_major_size(), shape.rank(),
absl::StrJoin(layout.minor_to_major(), ", "),
shape.ShortDebugString());
}
std::vector<bool> dimensions_in_layout(ShapeUtil::Rank(shape), false);
for (int64 i = 0; i < ShapeUtil::Rank(shape); ++i) {
std::vector<bool> dimensions_in_layout(shape.rank(), false);
for (int64 i = 0; i < shape.rank(); ++i) {
int64 dim = layout.minor_to_major(i);
if (dim < 0 || dim >= ShapeUtil::Rank(shape)) {
if (dim < 0 || dim >= shape.rank()) {
return InvalidArgument(
"layout minor_to_major field has out-of-bounds value: %s",
HumanString(layout));
@ -376,7 +376,7 @@ Status CopyLayoutInternal(const Shape& src, Shape* dst) {
}
} else {
if (src.has_layout()) {
if (ShapeUtil::Rank(src) != ShapeUtil::Rank(*dst)) {
if (src.rank() != dst->rank()) {
return InvalidArgument("cannot copy layout from shape: ranks differs");
}
TF_RETURN_IF_ERROR(
@ -410,7 +410,7 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
}
return true;
} else if (ShapeUtil::IsArray(lhs)) {
return ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) &&
return lhs.rank() == rhs.rank() &&
LayoutUtil::Equal(lhs.layout(), rhs.layout());
} else {
// Layouts of non-array and non-tuple shapes is ignored.

View File

@ -129,7 +129,7 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
new char[max_sparse_elements *
ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]);
piece->set_sparse_indices(
new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape)));
new SparseIndexArray(max_sparse_elements, shape.rank()));
} else {
piece->set_buffer(new char[piece->size_bytes()]);
}
@ -208,16 +208,15 @@ template <typename NativeT>
Status MutableLiteralBase::CopySliceFromInternal(
const LiteralBase& src_literal, absl::Span<const int64> src_base,
absl::Span<const int64> dest_base, absl::Span<const int64> copy_size) {
TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size());
TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size());
TF_RET_CHECK(src_literal.shape().rank() == src_base.size());
TF_RET_CHECK(shape().rank() == dest_base.size());
auto linear_index = [](const Shape& shape,
absl::Span<const int64> multi_index) {
return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
};
if (ShapeUtil::Rank(src_literal.shape()) == 0 ||
ShapeUtil::Rank(shape()) == 0) {
if (src_literal.shape().rank() == 0 || shape().rank() == 0) {
// If any of the two shapes are scalars, we can just call the StridedCopy()
// directly, and we know we will be copying only one value.
TF_RET_CHECK(copy_size.empty());
@ -375,7 +374,7 @@ void CopyElementsBetween(absl::Span<NativeT> dest,
if (ShapeUtil::IsZeroElementArray(dest_shape)) {
return;
}
std::vector<int64> index(ShapeUtil::Rank(dest_shape));
std::vector<int64> index(dest_shape.rank());
do {
dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
@ -392,7 +391,7 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
memcpy(buffer(), src.buffer(), src.size_bytes());
} else {
TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape()));
std::vector<int64> origin(ShapeUtil::Rank(subshape()), 0);
std::vector<int64> origin(subshape().rank(), 0);
switch (subshape().element_type()) {
#define COPY_ELEMENTS(XLA_T, NATIVE_T) \
case (XLA_T): \
@ -563,7 +562,7 @@ Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal,
void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(ShapeUtil::Rank(shape()), 1);
CHECK_EQ(shape().rank(), 1);
CHECK_EQ(element_count(), values.bits());
CHECK_EQ(shape().element_type(), PRED);
for (int64 i = 0; i < static_cast<int64>(values.bits()); ++i) {
@ -648,8 +647,7 @@ StatusOr<Literal> LiteralBase::Reshape(
}
Literal output;
if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
output =
Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape())));
output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank()));
} else {
output = Clone();
}
@ -672,7 +670,7 @@ StatusOr<Literal> LiteralBase::Reshape(
Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
CHECK(IsPermutation(permutation, shape().rank()))
<< "Given permutation is not a permutation of dimension numbers";
// To transpose the array, we just permute the dimensions and layout, and
// do a straight memory copy of the raw data set.
@ -711,10 +709,10 @@ template <typename NativeT>
Literal LiteralBase::SliceInternal(
const Shape& result_shape, absl::Span<const int64> start_indices) const {
Literal result_literal(result_shape);
DimensionVector new_indices(ShapeUtil::Rank(result_shape));
DimensionVector new_indices(result_shape.rank());
result_literal.EachCell<NativeT>(
[&](absl::Span<const int64> indices, NativeT /*value*/) {
for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
for (int64 i = 0; i < result_shape.rank(); ++i) {
new_indices[i] = indices[i] + start_indices[i];
}
NativeT value = Get<NativeT>(new_indices);
@ -728,7 +726,7 @@ Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
DimensionVector result_dimensions;
for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) {
for (int64 dnum = 0; dnum < shape().rank(); ++dnum) {
CHECK_GE(start_indices[dnum], 0);
CHECK_LE(limit_indices[dnum], shape().dimensions(dnum))
<< "dnum = " << dnum;
@ -1056,7 +1054,7 @@ void SparseArrayToStringHelper(const LiteralBase& literal,
pieces->push_back(ShapeToString(print_layout, subshape));
}
pieces->push_back("{");
int64 rank = ShapeUtil::Rank(subshape);
int64 rank = subshape.rank();
int64 num_elements = literal.sparse_element_count();
for (int64 i = 0; i < num_elements; ++i) {
if (i > 0) {
@ -1079,7 +1077,7 @@ void DenseArrayToStringHelper(const LiteralBase& literal,
const ShapeIndex& shape_index, bool print_shape,
bool print_layout, std::vector<string>* pieces) {
const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
int64 rank = ShapeUtil::Rank(subshape);
int64 rank = subshape.rank();
std::function<void(absl::Span<const int64> dimensions, std::vector<int64>*)>
to_string_recursive = [&](absl::Span<const int64> dimensions,
@ -1433,7 +1431,7 @@ StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const {
template <typename NativeT>
bool LiteralBase::Piece::EqualElementsInternal(
const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
if (multi_index->size() == ShapeUtil::Rank(subshape())) {
if (multi_index->size() == subshape().rank()) {
return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
}
for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) {
@ -1722,7 +1720,7 @@ bool LiteralBase::IsR1Iota() const {
return false;
}
if (ShapeUtil::Rank(shape()) != 1) {
if (shape().rank() != 1) {
return false;
}
@ -1932,14 +1930,12 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
if (LayoutUtil::IsSparseArray(subshape())) {
// Compute the number of elements (indices) in the sparse shape and reserve
// the necessary space in spare_indices.
TF_RET_CHECK(ShapeUtil::Rank(subshape()) != 0)
<< "Scalar shapes cannot be sparse";
TF_RET_CHECK(proto.sparse_indices_size() % ShapeUtil::Rank(subshape()) == 0)
TF_RET_CHECK(subshape().rank() != 0) << "Scalar shapes cannot be sparse";
TF_RET_CHECK(proto.sparse_indices_size() % subshape().rank() == 0)
<< "Unexpected number of indices in proto ("
<< proto.sparse_indices_size() << ") for shape of rank "
<< ShapeUtil::Rank(subshape());
const int64 index_count =
proto.sparse_indices_size() / ShapeUtil::Rank(subshape());
<< subshape().rank();
const int64 index_count = proto.sparse_indices_size() / subshape().rank();
sparse_indices()->Resize(index_count);
// Copy the indices from the proto into the SparseIndexArray object.
@ -2065,7 +2061,7 @@ int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
string LiteralBase::GetR1U8AsString() const {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(ShapeUtil::Rank(shape()), 1);
CHECK_EQ(shape().rank(), 1);
CHECK_EQ(shape().element_type(), U8);
return string(absl::bit_cast<const char*>(data<uint8>().data()),
ShapeUtil::ElementsIn(shape()));

View File

@ -961,7 +961,7 @@ void MutableLiteralBase::AppendSparseElement(
Piece& p = piece(shape_index);
const Shape& subshape = p.subshape();
CHECK(LayoutUtil::IsSparseArray(subshape));
int64 rank = ShapeUtil::Rank(subshape);
int64 rank = subshape.rank();
CHECK_EQ(multi_index.size(), rank);
int64 last_element = p.sparse_indices()->index_count();
CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout()));
@ -977,7 +977,7 @@ void LiteralBase::EachCell(
if (ShapeUtil::IsZeroElementArray(shape())) {
return;
}
std::vector<int64> indices(ShapeUtil::Rank(shape()), 0);
std::vector<int64> indices(shape().rank(), 0);
do {
per_cell(indices, Get<NativeT>(indices));
} while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
@ -986,7 +986,7 @@ void LiteralBase::EachCell(
template <typename NativeT>
inline void MutableLiteralBase::PopulateR1(absl::Span<const NativeT> values) {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(ShapeUtil::Rank(shape()), 1);
CHECK_EQ(shape().rank(), 1);
CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
CHECK_EQ(shape().element_type(),
primitive_util::NativeToPrimitiveType<NativeT>());
@ -998,7 +998,7 @@ template <typename NativeT>
void MutableLiteralBase::PopulateR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(ShapeUtil::Rank(shape()), 2);
CHECK_EQ(shape().rank(), 2);
CHECK_EQ(shape().element_type(),
primitive_util::NativeToPrimitiveType<NativeT>());
@ -1024,7 +1024,7 @@ void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(shape().element_type(),
primitive_util::NativeToPrimitiveType<NativeT>());
CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions());
CHECK_EQ(shape().rank(), values.num_dimensions());
for (int dim = 0; dim < values.num_dimensions(); ++dim) {
CHECK_EQ(values.dim(dim), shape().dimensions(dim));
}
@ -1053,7 +1053,7 @@ void MutableLiteralBase::PopulateSparse(SparseIndexArray indices,
absl::Span<const NativeT> values,
bool sort) {
CHECK(LayoutUtil::IsSparseArray(shape()));
int rank = ShapeUtil::Rank(shape());
int rank = shape().rank();
CHECK_EQ(indices.rank(), rank);
int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout());
CHECK_LE(indices.max_indices(), max_elements);
@ -1077,7 +1077,7 @@ template <typename NativeT, typename FnType>
Status MutableLiteralBase::PopulateInternal(const FnType& generator,
bool parallel) {
const Shape& this_shape = shape();
const int64 rank = ShapeUtil::Rank(this_shape);
const int64 rank = this_shape.rank();
TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
TF_RET_CHECK(this_shape.element_type() ==
primitive_util::NativeToPrimitiveType<NativeT>());

View File

@ -463,7 +463,7 @@ class NearComparator {
}
return;
}
std::vector<int64> multi_index(ShapeUtil::Rank(actual_.shape()), 0);
std::vector<int64> multi_index(actual_.shape().rank(), 0);
CompareLiteralsSlow(0, &multi_index);
}
@ -777,7 +777,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) {
}
}
} else if (ShapeUtil::IsArray(expected)) {
if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) {
if (expected.rank() != actual.rank()) {
return InvalidArgument("want rank of %s got rank of %s",
ShapeUtil::HumanString(expected),
ShapeUtil::HumanString(actual));

View File

@ -132,7 +132,7 @@ PyObject* PyShapeInfoFromXlaShape(const Shape& shape) {
PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i)));
}
} else {
int rank = ShapeUtil::Rank(shape);
int rank = shape.rank();
dimensions = PyTuple_New(rank);
for (int i = 0; i < rank; ++i) {
PyTuple_SET_ITEM(dimensions, i,
@ -354,7 +354,7 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) {
}
return tuple;
} else {
int rank = ShapeUtil::Rank(literal.shape());
int rank = literal.shape().rank();
std::vector<long> dimensions(rank); // NOLINT - PyArray requires a long*
for (int i = 0; i < rank; i++) {
dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i);

View File

@ -552,7 +552,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
Literal result_literal =
evaluator.Evaluate<const Literal*>(*computation, {}).ConsumeValueOrDie();
CHECK_EQ(ShapeUtil::Rank(result_literal.shape()), 4);
CHECK_EQ(result_literal.shape().rank(), 4);
auto result =
absl::make_unique<Array4D<float>>(result_literal.shape().dimensions(0),
result_literal.shape().dimensions(1),

View File

@ -251,7 +251,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
// Reshapes an instruction to rank 1 if it is not already rank 1.
HloInstruction* Flatten(HloInstruction* hlo) {
if (ShapeUtil::Rank(hlo->shape()) == 1) {
if (hlo->shape().rank() == 1) {
return hlo;
}
return computation_->AddInstruction(HloInstruction::CreateReshape(
@ -687,7 +687,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate(
return Status::OK();
}
PaddingConfig padding_config;
for (int64 dim = 0; dim < ShapeUtil::Rank(operands[0]->shape()); ++dim) {
for (int64 dim = 0; dim < operands[0]->shape().rank(); ++dim) {
auto padding_config_dim = padding_config.add_dimensions();
padding_config_dim->set_edge_padding_high(0);
padding_config_dim->set_edge_padding_low(0);
@ -754,7 +754,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
}
// If a literal is an increasing sequence from zero, replace it with an iota.
if (ShapeUtil::Rank(constant->shape()) == 1 &&
if (constant->shape().rank() == 1 &&
ShapeUtil::ElementsIn(constant->shape()) > 1 &&
constant->literal().IsR1Iota()) {
return ReplaceWithNewInstruction(
@ -930,9 +930,9 @@ StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
return -1;
};
const int64 dot_rank = ShapeUtil::Rank(dot->shape());
const int64 rhs_rank = ShapeUtil::Rank(rhs->shape());
const int64 lhs_rank = ShapeUtil::Rank(lhs->shape());
const int64 dot_rank = dot->shape().rank();
const int64 rhs_rank = rhs->shape().rank();
const int64 lhs_rank = lhs->shape().rank();
const auto& dnums = dot->dot_dimension_numbers();
if (dnums.rhs_contracting_dimensions_size() > 1) {
return false;
@ -1036,7 +1036,7 @@ StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
// )
if (lhs_rank == 1 ||
(lhs_rank == 2 && lhs->shape().dimensions(lhs_kept_dim) == 1)) {
if (ShapeUtil::Rank(rhs->shape()) == 1) {
if (rhs->shape().rank() == 1) {
TF_RETURN_IF_ERROR(
ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32(
multiply(Flatten(lhs), rhs), 0))));
@ -1449,8 +1449,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
dot->shape().element_type() != BF16) {
return Status::OK();
}
if (ShapeUtil::Rank(lhs->shape()) > 2 || ShapeUtil::Rank(rhs->shape()) > 2 ||
ShapeUtil::Rank(dot->shape()) > 2) {
if (lhs->shape().rank() > 2 || rhs->shape().rank() > 2 ||
dot->shape().rank() > 2) {
if (options_.enable_dot_strength_reduction() &&
!options_.is_layout_sensitive()) {
TF_RETURN_IF_ERROR(HandleDotStrengthReduction(dot).status());
@ -1732,8 +1732,7 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
// A degenerate broadcast that has the same input and output rank can be
// converted into a transpose.
if (ShapeUtil::Rank(broadcast->shape()) ==
ShapeUtil::Rank(operand->shape()) &&
if (broadcast->shape().rank() == operand->shape().rank() &&
ShapeUtil::ElementsIn(broadcast->shape()) ==
ShapeUtil::ElementsIn(operand->shape())) {
VLOG(10) << "transform broadcast(X) -> transpose(X) where "
@ -1888,7 +1887,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
if (HasInteriorPadding(pad->padding_config())) {
PaddingConfig padding_config = pad->padding_config();
bool cleared_interior_padding = false;
for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) {
for (int64 i = 0; i < pad->shape().rank(); ++i) {
if (padding_config.dimensions(i).interior_padding() > 0 &&
pad->operand(0)->shape().dimensions(i) == 1) {
cleared_interior_padding = true;
@ -2276,7 +2275,7 @@ StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice(
if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) {
VLOG(10) << "Trying to simplify scalar slice of concat";
// Only do this for R1, there's no chance of this being useful otherwise.
if (ShapeUtil::Rank(slice->shape()) != 1) {
if (slice->shape().rank() != 1) {
VLOG(10) << "Not folding, slice is not rank 1";
return false;
}
@ -2326,7 +2325,7 @@ StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape(
return false;
}
HloInstruction* new_slice_operand = reshape->mutable_operand(0);
int64 slice_rank = ShapeUtil::Rank(slice->shape());
int64 slice_rank = slice->shape().rank();
std::vector<int64> sliced_dims;
for (int64 i = 0; i < slice_rank; ++i) {
if (slice->slice_starts(i) != 0 ||
@ -2338,7 +2337,7 @@ StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape(
if (sliced_dims.size() == 1 && sliced_dims[0] == 0 &&
slice->slice_starts(0) == 0) {
const Shape& new_slice_shape = new_slice_operand->shape();
const int64 rank = ShapeUtil::Rank(new_slice_shape);
const int64 rank = new_slice_shape.rank();
std::vector<int64> new_slice_starts(rank, 0);
std::vector<int64> new_slice_stides(rank, 1);
std::vector<int64> new_slice_limits(new_slice_shape.dimensions().begin(),
@ -2456,8 +2455,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
// A Transpose feeding a reduce can simply permute the reduction dimensions
// field if the output of the reduce is a vector or scalar. Higher ranked
// result may require a transpose of the output.
if (ShapeUtil::Rank(reduce->shape()) <= 1 &&
arg->opcode() == HloOpcode::kTranspose) {
if (reduce->shape().rank() <= 1 && arg->opcode() == HloOpcode::kTranspose) {
auto transpose_dimensions = arg->dimensions();
std::vector<int64> new_reduce_dimensions;
for (auto dim : dimensions) {
@ -2516,8 +2514,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
std::vector<std::pair<int64, int64>> unmodified_dims =
ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(),
arg->shape());
std::vector<bool> arg_dim_in_output(ShapeUtil::Rank(arg->shape()), true);
std::vector<bool> arg_dim_unmodified(ShapeUtil::Rank(arg->shape()), false);
std::vector<bool> arg_dim_in_output(arg->shape().rank(), true);
std::vector<bool> arg_dim_unmodified(arg->shape().rank(), false);
for (auto dim : dimensions) {
arg_dim_in_output[dim] = false;
}
@ -2542,7 +2540,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
}
}
std::vector<int64> new_reduce_dimensions;
for (int64 i = 0; i < ShapeUtil::Rank(arg->operand(0)->shape()); ++i) {
for (int64 i = 0; i < arg->operand(0)->shape().rank(); ++i) {
if (dimensions_not_to_reduce.count(i) == 0) {
new_reduce_dimensions.push_back(i);
}
@ -2779,7 +2777,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
// Carry out the folding of the pad into reduce_window.
VLOG(10) << "Folding pad into reduce-window.";
Window new_window = window;
const int64 rank = ShapeUtil::Rank(reduce_window->shape());
const int64 rank = reduce_window->shape().rank();
TF_RET_CHECK(pad_config.dimensions_size() == rank);
TF_RET_CHECK(window.dimensions_size() == rank);
for (int64 i = 0; i < rank; ++i) {
@ -2862,7 +2860,7 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) {
// - Use this as the indices parameter of scatter, and set updates
// of the scatter to be a reshaped 'values' parameter of sort (adding
// 'rank' many 1 dimensions at the end).
int64 rank = ShapeUtil::Rank(operand->shape());
int64 rank = operand->shape().rank();
Shape extended_shape = operand->shape();
extended_shape.add_dimensions(1);
extended_shape.mutable_layout()->add_minor_to_major(rank);

View File

@ -3498,7 +3498,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
// Create the reduce-window.
Window window;
for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) {
for (int64 i = 0; i < pad->shape().rank(); ++i) {
auto* dim = window.add_dimensions();
dim->set_size(1);
dim->set_padding_low(10);
@ -3584,7 +3584,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
// Create the reduce-window.
Window window;
for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) {
for (int64 i = 0; i < pad->shape().rank(); ++i) {
auto* dim = window.add_dimensions();
dim->set_size(1);
dim->set_padding_low(10);

View File

@ -123,7 +123,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
auto elements_per_feature_u32 = add_instruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(1)));
for (int64 i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) {
for (int64 i = 0; i < operand->shape().rank(); ++i) {
if (i == feature_index) {
continue;
}
@ -229,7 +229,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {}));
std::vector<int64> dimensions_without_feature;
for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
for (int64 i = 0; i < operand_shape.rank(); ++i) {
if (i != feature_index) {
dimensions_without_feature.push_back(i);
}
@ -357,7 +357,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
std::vector<int64> dimensions_without_feature;
for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
for (int64 i = 0; i < operand_shape.rank(); ++i) {
if (i != feature_index) {
dimensions_without_feature.push_back(i);
}
@ -494,7 +494,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
std::vector<int64> dimensions_without_feature;
for (int64 i = 0; i < ShapeUtil::Rank(activation_shape); ++i) {
for (int64 i = 0; i < activation_shape.rank(); ++i) {
if (i != feature_index) {
dimensions_without_feature.push_back(i);
}

View File

@ -1550,7 +1550,7 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
}
// Return whether the given shape is rank 2.
static bool IsRank2(const Shape& shape) { return ShapeUtil::Rank(shape) == 2; }
static bool IsRank2(const Shape& shape) { return shape.rank() == 2; }
// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.

View File

@ -535,7 +535,7 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) {
higher_dimensions *= normalized_keys_shape.dimensions(i);
}
int64 lower_dimensions = 1;
for (int64 i = ShapeUtil::Rank(normalized_keys_shape) - 1;
for (int64 i = normalized_keys_shape.rank() - 1;
i > physical_dimension_to_sort; --i) {
lower_dimensions *= normalized_keys_shape.dimensions(i);
}
@ -779,8 +779,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
const auto init_value = select_and_scatter->operand(2);
const Window& window = select_and_scatter->window();
PrimitiveType operand_element_type = operand->shape().element_type();
const int64 rank = ShapeUtil::Rank(operand->shape());
CHECK_EQ(rank, ShapeUtil::Rank(source->shape()));
const int64 rank = operand->shape().rank();
CHECK_EQ(rank, source->shape().rank());
CHECK_EQ(rank, window.dimensions_size());
// TODO(b/31410564): Implement dilation for select-and-scatter.

View File

@ -173,7 +173,7 @@ Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) {
// Find out the new dynamic dimension after reduce.
int64 dimensions_not_reduced_count = 0;
for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) {
for (int i = 0; i < operand->shape().rank(); ++i) {
if (dimension == i) {
parent_->SetDynamicSize(reduce, {}, dimensions_not_reduced_count,
dynamic_size);
@ -207,7 +207,7 @@ Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
result_dim_mapping[i] = current_result_dims++;
}
for (int64 i = 0; i < ShapeUtil::Rank(dot->operand(0)->shape()); i++) {
for (int64 i = 0; i < dot->operand(0)->shape().rank(); i++) {
if (!absl::c_linear_search(
dimension_numbers.lhs_contracting_dimensions(), i)) {
if (operand_index == 0) {
@ -217,7 +217,7 @@ Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
}
}
for (int64 i = 0; i < ShapeUtil::Rank(dot->operand(1)->shape()); i++) {
for (int64 i = 0; i < dot->operand(1)->shape().rank(); i++) {
if (!absl::c_linear_search(
dimension_numbers.rhs_contracting_dimensions(), i) &&
!absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(),

View File

@ -121,10 +121,11 @@ Status DynamicParameterBinding::Verify(const HloModule& module) const {
dynamic_dimension.parameter_index));
TF_RET_CHECK(
dynamic_dimension.dimension <
ShapeUtil::Rank(ShapeUtil::GetSubshape(
ShapeUtil::GetSubshape(
entry->parameter_instruction(dynamic_dimension.parameter_num)
->shape(),
dynamic_dimension.parameter_index)));
dynamic_dimension.parameter_index)
.rank());
return Status::OK();
});
}

View File

@ -1327,9 +1327,9 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
// If implicit broadcast is needed, the source dimensions that are broadcast
// have index 0.
CHECK_EQ(ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(hlo.shape()));
CHECK_EQ(operand_shape.rank(), hlo.shape().rank());
llvm_ir::IrArray::Index source_index(target_index.GetType());
for (int64 i = 0; i < ShapeUtil::Rank(hlo.shape()); ++i) {
for (int64 i = 0; i < hlo.shape().rank(); ++i) {
if (hlo.shape().dimensions(i) == operand_shape.dimensions(i)) {
source_index.push_back(target_index[i]);
} else {
@ -1750,7 +1750,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
const llvm_ir::IrArray::Index& index) {
// Emit IR to read dynamic start indices from hlo->operand(1).
const HloInstruction* input_hlo = hlo->operand(0);
const int64 rank = ShapeUtil::Rank(input_hlo->shape());
const int64 rank = input_hlo->shape().rank();
// Use the same index type for all tensor accesses in the same kernel.
llvm::Type* index_type = index.GetType();
llvm_ir::IrArray::Index slice_start_index(index_type, rank);
@ -1893,7 +1893,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
const HloInstruction* update_hlo = hlo->operand(1);
const HloInstruction* start_hlo = hlo->operand(2);
// Calculate slice start/end indices.
const int64 rank = ShapeUtil::Rank(input_hlo->shape());
const int64 rank = input_hlo->shape().rank();
llvm_ir::IrArray::Index slice_start_index(index.GetType(), rank);
llvm_ir::IrArray::Index slice_limit_index(index.GetType(), rank);
// Slice intersection gathers (ANDs) conditions on all ranks for which
@ -2225,7 +2225,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
auto* iota = Cast<HloIotaInstruction>(hlo);
PrimitiveType element_type = iota->shape().element_type();
IrArray::Index elem_index =
ShapeUtil::Rank(iota->shape()) > 1
iota->shape().rank() > 1
? target_index.SourceIndexOfBroadcast(
iota->shape(),
ShapeUtil::MakeShapeWithDescendingLayout(

View File

@ -50,10 +50,10 @@ static HloInstruction* PadInstruction(HloInstruction* instr,
auto* zero = comp->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape));
PaddingConfig pad_config = MakeNoPaddingConfig(shape.rank());
bool added_padding = false;
for (int64 dim = 0; dim < ShapeUtil::Rank(shape); ++dim) {
for (int64 dim = 0; dim < shape.rank(); ++dim) {
if (shape.dimensions(dim) == new_shape.dimensions(dim)) {
continue;
}

View File

@ -219,7 +219,7 @@ bool CudnnConvPaddingLegalization::CanonicalizeBackwardFilterConvolution(
Window new_backward_conv_window = backward_conv->window();
// input_padding_config is the config of the kPad to be inserted.
PaddingConfig input_padding_config =
MakeNoPaddingConfig(ShapeUtil::Rank(input->shape()));
MakeNoPaddingConfig(input->shape().rank());
ConvolutionDimensionNumbers backward_conv_dnums =
backward_conv->convolution_dimension_numbers();
for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {

View File

@ -315,8 +315,7 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction());
CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
dim_nums.rhs_batch_dimensions_size());
CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2,
ShapeUtil::Rank(output_shape_));
CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, output_shape_.rank());
int64 row_dim = dim_nums.lhs_batch_dimensions_size();
int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1;

View File

@ -43,14 +43,14 @@ bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
const Layout* max_rank_layout;
for (HloInstruction* param : params) {
if (ShapeUtil::IsArray(param->shape()) &&
ShapeUtil::Rank(param->shape()) > max_rank) {
max_rank = ShapeUtil::Rank(param->shape());
param->shape().rank() > max_rank) {
max_rank = param->shape().rank();
max_rank_layout = &param->shape().layout();
}
}
return absl::c_all_of(params, [&](HloInstruction* param) {
return (!ShapeUtil::IsArray(param->shape())) ||
(ShapeUtil::Rank(param->shape()) < max_rank) ||
(param->shape().rank() < max_rank) ||
(LayoutUtil::Equal(param->shape().layout(), *max_rank_layout));
});
}

View File

@ -196,9 +196,9 @@ Status GpuLayoutAssignment::AddBackendConstraints(
CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
dim_nums.rhs_batch_dimensions_size());
CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2,
ShapeUtil::Rank(instruction->shape()));
instruction->shape().rank());
for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
CHECK_LT(batch_dim, ShapeUtil::Rank(instruction->shape()) - 2);
CHECK_LT(batch_dim, instruction->shape().rank() - 2);
}
// Set both inputs and the output to default layout.
@ -215,11 +215,11 @@ Status GpuLayoutAssignment::AddBackendConstraints(
TF_RETURN_IF_ERROR(
constraints->SetInstructionLayout(output_shape, instruction));
} else if (instruction->opcode() == HloOpcode::kSort &&
ShapeUtil::Rank(instruction->operand(0)->shape()) > 1) {
instruction->operand(0)->shape().rank() > 1) {
// Make sure that all the operands and the output(s) have the same layout.
Shape keys_shape = instruction->operand(0)->shape();
Layout keys_layout =
LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(keys_shape));
LayoutUtil::GetDefaultLayoutForRank(keys_shape.rank());
for (int64 i = 0; i < instruction->operand_count(); ++i) {
Shape shape = instruction->operand(i)->shape();
*shape.mutable_layout() = keys_layout;

View File

@ -40,7 +40,7 @@ namespace {
// Return whether the given shape is rank 2 excluding the batch dimensions.
bool IsRank2(const Shape& shape, int64 batch_dimensions_size) {
return ShapeUtil::Rank(shape) == batch_dimensions_size + 2;
return shape.rank() == batch_dimensions_size + 2;
}
// In a gemm operation where output = lhs * rhs, check whether the given shapes

View File

@ -698,8 +698,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
const auto* source = select_and_scatter->operand(1);
const Window& window = select_and_scatter->window();
PrimitiveType operand_element_type = operand->shape().element_type();
const int64 rank = ShapeUtil::Rank(operand->shape());
CHECK_EQ(rank, ShapeUtil::Rank(source->shape()));
const int64 rank = operand->shape().rank();
CHECK_EQ(rank, source->shape().rank());
CHECK_EQ(rank, window.dimensions_size());
TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
@ -1015,7 +1015,7 @@ Status IrEmitterUnnested::EmitScatter(
int64 raw_window_multidim_idx = 0;
std::vector<llvm::Value*> input_window_multidim;
std::vector<int64> input_window_bounds;
for (int64 i = 0, e = ShapeUtil::Rank(operand->shape()); i != e; ++i) {
for (int64 i = 0, e = operand->shape().rank(); i != e; ++i) {
if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) {
input_window_bounds.push_back(1); // Trivial dimension.
input_window_multidim.push_back(index.GetConstantWithIndexType(0));
@ -1027,12 +1027,11 @@ Status IrEmitterUnnested::EmitScatter(
++raw_window_multidim_idx;
}
}
DCHECK_EQ(input_window_multidim.size(), ShapeUtil::Rank(operand->shape()));
DCHECK_EQ(input_window_multidim.size(), operand->shape().rank());
// Insert a 1 dimension at the end if index_vector_dim requests one.
Shape scatter_indices_shape = scatter_indices->shape();
if (dim_numbers.index_vector_dim() ==
ShapeUtil::Rank(scatter_indices_shape)) {
if (dim_numbers.index_vector_dim() == scatter_indices_shape.rank()) {
scatter_indices_shape.add_dimensions(1);
scatter_indices_shape.mutable_layout()->add_minor_to_major(
dim_numbers.index_vector_dim());
@ -3191,7 +3190,7 @@ Status AreFusedReductionOutputsConsistent(
// dimensions from minor to major.
DimensionVector GetDimensionsToKeepMinorToMajor(
const Shape& input_shape, absl::Span<const int64> dims_to_reduce) {
DimensionVector input_dims(ShapeUtil::Rank(input_shape), 0);
DimensionVector input_dims(input_shape.rank(), 0);
absl::c_iota(input_dims, 0);
DimensionVector input_dims_to_keep;
for (int input_dim : input_dims) {
@ -3231,7 +3230,7 @@ std::tuple<int64, int64, int64> GetReductionToVectorDimensions(
if (input_dims_to_keep_minor_to_major.empty()) {
return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor);
}
DimensionVector input_dims(ShapeUtil::Rank(input_shape), 0);
DimensionVector input_dims(input_shape.rank(), 0);
absl::c_iota(input_dims, 0);
absl::Span<const int64> minor_to_major =
LayoutUtil::MinorToMajor(input_shape);

View File

@ -189,8 +189,7 @@ StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
for (const HloInstruction* operand : operands) {
CHECK_EQ(computation, operand->parent());
operand_shapes.push_back(&operand->shape());
max_operand_rank =
std::max(max_operand_rank, ShapeUtil::Rank(operand->shape()));
max_operand_rank = std::max(max_operand_rank, operand->shape().rank());
}
std::vector<int64> map_dims(max_operand_rank);
std::iota(map_dims.begin(), map_dims.end(), 0);
@ -207,7 +206,7 @@ StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
HloOpcode binary_opcode,
HloModule* module) {
DCHECK_NE(nullptr, module);
std::vector<int64> all_dims(ShapeUtil::Rank(operand->shape()));
std::vector<int64> all_dims(operand->shape().rank());
std::iota(all_dims.begin(), all_dims.end(), 0);
auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {});

View File

@ -443,7 +443,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
// concatenate dimensions of the operands taking part of the operation.
const Shape& reference_shape = operands[0]->shape();
CHECK(ShapeUtil::IsArray(reference_shape));
const int64 rank = ShapeUtil::Rank(reference_shape);
const int64 rank = reference_shape.rank();
const int64 concat_dim = concatenate->dimensions()[0];
CHECK_GE(concat_dim, 0);
CHECK_LT(concat_dim, rank);
@ -1036,11 +1036,9 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0));
TF_RET_CHECK(broadcast->dimensions().size() ==
ShapeUtil::Rank(operand.shape()))
TF_RET_CHECK(broadcast->dimensions().size() == operand.shape().rank())
<< "broadcast dimensions is of size: " << broadcast->dimensions().size()
<< " and rank of operand_to_broadcast is: "
<< ShapeUtil::Rank(operand.shape());
<< " and rank of operand_to_broadcast is: " << operand.shape().rank();
// Checks that operand's dimensions are the same as the broadcast's
// dimensions along the dimensions to be broadcasted.
for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
@ -1251,7 +1249,7 @@ template <typename KeyType, typename ValueType>
StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort,
const Literal& keys_literal,
const Literal& values_literal) {
auto rank = ShapeUtil::Rank(keys_literal.shape());
auto rank = keys_literal.shape().rank();
TF_RET_CHECK(
ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape()))
<< "Sort keys and values must have the same dimensions";

View File

@ -1005,8 +1005,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
CHECK_GE(num_spatial_dims, 0);
CHECK_EQ(window.dimensions_size(), num_spatial_dims);
const auto lhs_rank = ShapeUtil::Rank(lhs_shape);
const auto rhs_rank = ShapeUtil::Rank(rhs_shape);
const auto lhs_rank = lhs_shape.rank();
const auto rhs_rank = rhs_shape.rank();
CHECK_EQ(num_spatial_dims + 2, lhs_rank);
CHECK_EQ(num_spatial_dims + 2, rhs_rank);
@ -1175,8 +1175,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const auto& dnums = dot->dot_dimension_numbers();
const int64 lhs_rank = ShapeUtil::Rank(lhs->shape());
const int64 rhs_rank = ShapeUtil::Rank(rhs->shape());
const int64 lhs_rank = lhs->shape().rank();
const int64 rhs_rank = rhs->shape().rank();
CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape()));
CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape()));
@ -1238,8 +1238,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const auto& dnums = dot->dot_dimension_numbers();
const auto lhs_rank = ShapeUtil::Rank(lhs->shape());
const auto rhs_rank = ShapeUtil::Rank(rhs->shape());
const auto lhs_rank = lhs->shape().rank();
const auto rhs_rank = rhs->shape().rank();
CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape()));
CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape()));
@ -1329,7 +1329,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
CHECK(ShapeUtil::IsArray(pad->operand(0)->shape()));
// Padding value must be scalar.
CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape()));
CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()),
CHECK_EQ(pad->operand(0)->shape().rank(),
pad->padding_config().dimensions_size());
TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
@ -1352,9 +1352,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& evaluated_operand =
parent_->GetEvaluatedLiteralFor(pad->operand(0));
std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()),
0);
std::vector<int64> target_index(ShapeUtil::Rank(result.shape()), 0);
std::vector<int64> input_index(evaluated_operand.shape().rank(), 0);
std::vector<int64> target_index(result.shape().rank(), 0);
// Loop through each element of the operand, assign them to the
// corresponding index of the resulting padded literal.
@ -1609,7 +1608,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys);
int64 sort_dim = sort->dimensions(0);
int64 sort_dim_elements = keys->shape().dimensions(sort_dim);
int64 rank = ShapeUtil::Rank(keys->shape());
int64 rank = keys->shape().rank();
if (rank == 0) {
// Nothing to sort.
parent_->evaluated_[sort] = keys_literal.Clone();
@ -1868,7 +1867,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source);
int64 rank = ShapeUtil::Rank(operand_literal.shape());
int64 rank = operand_literal.shape().rank();
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
DimensionVector source_index(rank, 0);
@ -1980,7 +1979,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
operand->shape().element_type(), window_dimension_sizes);
DimensionVector window_index(window.dimensions_size());
DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape()));
DimensionVector operand_index(operand_literal.shape().rank());
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
Literal result(reduce_window->shape());
@ -2411,7 +2410,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
<< " but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);
const int64 rank = ShapeUtil::Rank(operand->shape());
const int64 rank = operand->shape().rank();
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
auto func = [&](absl::Span<const int64> out_index) {
DimensionVector operand_index(rank);
@ -2648,12 +2647,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
auto result = LiteralUtil::CreateR1<NativeT>(data);
if (ShapeUtil::Rank(iota->shape()) > 1) {
if (iota->shape().rank() > 1) {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[iota],
result.Broadcast(iota->shape(), {iota->iota_dimension()}));
} else {
TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1);
TF_RET_CHECK(iota->shape().rank() == 1);
parent_->evaluated_[iota] = std::move(result);
}
@ -2683,7 +2682,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
//
// This lets you calculate LI given the multidimensional indices in any order.
static DimensionVector MakeDimMultipliers(const Shape& shape) {
DimensionVector v(ShapeUtil::Rank(shape));
DimensionVector v(shape.rank());
int64 scale = 1;
for (auto dim : LayoutUtil::MinorToMajor(shape)) {
v[dim] = scale;
@ -2700,7 +2699,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Shape& window_shape, const Window& window, const Shape& base_shape,
const absl::Span<const int64>& window_count_index,
const std::function<void(const std::vector<int64>&)>& f) {
const int64 rank = ShapeUtil::Rank(base_shape);
const int64 rank = base_shape.rank();
DimensionVector window_index(rank);
std::fill(window_index.begin(), window_index.end(), 0);
do {
@ -2767,7 +2766,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& start_indices_literal) {
auto result = operand_literal.Clone();
auto start_indices_typed = start_indices_literal.data<IndexT>();
const auto rank = ShapeUtil::Rank(result.shape());
const auto rank = result.shape().rank();
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
// Clamp the update start indices so the slice is in-bounds w.r.t the

View File

@ -1039,7 +1039,7 @@ HloInstruction::CreateBroadcastSequence(
const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
adder) {
CHECK(ShapeUtil::IsScalar(operand->shape()) ||
ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape));
operand->shape().rank() == output_shape.rank());
Shape broadcast_shape = ShapeUtil::ChangeElementType(
output_shape, operand->shape().element_type());
// Do explicit broadcast for scalar.
@ -1055,7 +1055,7 @@ HloInstruction::CreateBroadcastSequence(
// Do explicit broadcast for degenerate broadcast.
std::vector<int64> broadcast_dimensions;
std::vector<int64> reshaped_dimensions;
for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) {
for (int i = 0; i < operand->shape().rank(); i++) {
if (operand->shape().dimensions(i) == output_shape.dimensions(i)) {
broadcast_dimensions.push_back(i);
reshaped_dimensions.push_back(operand->shape().dimensions(i));

View File

@ -734,7 +734,7 @@ HloMapInstruction::HloMapInstruction(const Shape& shape,
AppendComputation(map_computation);
// TODO(b/65689298) Remove code below once Map is generalized to accept
// arbitrary map dimensions.
dimensions_.resize(ShapeUtil::Rank(shape));
dimensions_.resize(shape.rank());
std::iota(dimensions_.begin(), dimensions_.end(), 0);
}

View File

@ -1980,7 +1980,7 @@ bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
}
bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) {
const tensorflow::int64 rank = ShapeUtil::Rank(shape);
const tensorflow::int64 rank = shape.rank();
// Create a literal with the given shape in default layout.
*literal = LiteralUtil::CreateFromDimensions(
shape.element_type(), AsInt64Slice(shape.dimensions()));
@ -2145,7 +2145,7 @@ template <typename LiteralNativeT>
bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) {
std::vector<tensorflow::int64> index;
tensorflow::int64 rank = ShapeUtil::Rank(shape);
tensorflow::int64 rank = shape.rank();
*literal = Literal(shape);

View File

@ -30,7 +30,7 @@ HloSharding HloSharding::AssignDevice(int64 device_id) {
}
HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) {
CHECK_EQ(1, ShapeUtil::Rank(input_shape));
CHECK_EQ(1, input_shape.rank());
CHECK_GT(num_tiles, 1);
std::vector<int64> dimensions(1, num_tiles);
Array<int64> assignment(dimensions);
@ -340,7 +340,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
}
// The tile assignment tensor must have the same rank as the input.
if (ShapeUtil::Rank(shape) != tile_assignment_.num_dimensions()) {
if (shape.rank() != tile_assignment_.num_dimensions()) {
return tensorflow::errors::InvalidArgument(
"Number of tile assignment dimensions is different to the input rank. "
"sharding=",

View File

@ -349,7 +349,7 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
Status ShapeVerifier::HandleIota(HloInstruction* instruction) {
TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0));
auto* iota = Cast<HloIotaInstruction>(instruction);
const int64 rank = ShapeUtil::Rank(iota->shape());
const int64 rank = iota->shape().rank();
if (rank == 0) {
return InternalError("Iota does not support scalars.");
}
@ -397,13 +397,11 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
const Shape& operand_shape = broadcast->operand(0)->shape();
// Check for mixed precision.
TF_RET_CHECK(SameElementType(broadcast->shape(), operand_shape));
TF_RET_CHECK(ShapeUtil::Rank(operand_shape) ==
broadcast->dimensions().size());
for (int64 operand_dimension = 0;
operand_dimension < ShapeUtil::Rank(operand_shape);
TF_RET_CHECK(operand_shape.rank() == broadcast->dimensions().size());
for (int64 operand_dimension = 0; operand_dimension < operand_shape.rank();
++operand_dimension) {
int64 output_dimension = broadcast->dimensions()[operand_dimension];
TF_RET_CHECK((output_dimension < ShapeUtil::Rank(broadcast->shape())) &&
TF_RET_CHECK((output_dimension < broadcast->shape().rank()) &&
output_dimension >= 0 &&
(broadcast->shape().dimensions(output_dimension) ==
operand_shape.dimensions(operand_dimension)))
@ -524,8 +522,7 @@ Status ShapeVerifier::HandleMap(HloInstruction* map) {
int64 max_operand_rank = 0;
for (const HloInstruction* operand : map->operands()) {
operand_shapes.push_back(&operand->shape());
max_operand_rank =
std::max(max_operand_rank, ShapeUtil::Rank(operand->shape()));
max_operand_rank = std::max(max_operand_rank, operand->shape().rank());
}
// TODO(b/65689298) Remove code below once Map is generalized to accept
// arbitrary map dimensions.
@ -1271,11 +1268,11 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
// op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I
// or ComputationLowerer::Visit()
TF_RET_CHECK(broadcast->dimensions().size() ==
ShapeUtil::Rank(broadcast->operand(0)->shape()))
broadcast->operand(0)->shape().rank())
<< "Broadcast HLO (" << broadcast->ToShortString()
<< ") has invalid number of dimensions: "
<< broadcast->dimensions().size()
<< " != " << ShapeUtil::Rank(broadcast->operand(0)->shape());
<< " != " << broadcast->operand(0)->shape().rank();
return Status::OK();
}
@ -1376,7 +1373,7 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
for (HloInstruction* operand : instruction->operands()) {
const Shape& operand_shape = operand->shape();
if (LayoutUtil::IsDenseArray(operand_shape) &&
ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(result_shape)) {
operand_shape.rank() == result_shape.rank()) {
const Layout& operand_layout = operand_shape.layout();
TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout))
<< "Instruction shouldn't change layouts "

View File

@ -1002,7 +1002,7 @@ bool CanFoldDotIntoIndexedArray(
absl::Span<const int64> contracting_dims,
absl::Span<const int64> batch_dims) {
absl::optional<int64> non_contracting_non_batch_dim =
GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()),
GetOnlyNonContractingNonBatchDim(indexed_array->shape().rank(),
contracting_dims, batch_dims);
if (!non_contracting_non_batch_dim.has_value()) {
VLOG(3) << tag << ": multiple or no non-contracting non-batch dimensions";
@ -1015,7 +1015,7 @@ bool CanFoldDotIntoIndexedArray(
return false;
}
int64 indexed_array_rank = ShapeUtil::Rank(indexed_array->shape());
int64 indexed_array_rank = indexed_array->shape().rank();
if (indexed_array->source_dim() < (indexed_array_rank - 2)) {
// This restriction can be lifted by inserting reshape nodes.
VLOG(3) << tag
@ -1043,7 +1043,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
return nullptr;
}
int64 lhs_rank = ShapeUtil::Rank(lhs->shape());
int64 lhs_rank = lhs->shape().rank();
DotDimensionNumbers new_dim_numbers = dim_numbers;
new_dim_numbers.set_lhs_contracting_dimensions(
0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1));
@ -1078,7 +1078,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
return nullptr;
}
int64 rhs_rank = ShapeUtil::Rank(rhs->shape());
int64 rhs_rank = rhs->shape().rank();
DotDimensionNumbers new_dim_numbers = dim_numbers;
new_dim_numbers.set_rhs_contracting_dimensions(

View File

@ -991,8 +991,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
CHECK(ShapeUtil::IsArray(instruction->shape()));
CHECK(ShapeUtil::IsArray(operand->shape()));
if (!ShapeUtil::IsScalar(operand->shape()) &&
ShapeUtil::Rank(operand->shape()) ==
ShapeUtil::Rank(instruction->shape()) &&
operand->shape().rank() == instruction->shape().rank() &&
!instruction_can_change_layout_func_(instruction)) {
// Propagate the result layout to the operand layout if the instruction
// requires the same layout out for the result and the operand.
@ -1012,7 +1011,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
// operations. For similar reasons, if the operand and output have the same
// rank, try to match the operand's layout to the output.
if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
ShapeUtil::Rank(instruction->shape()) == 1) {
instruction->shape().rank() == 1) {
// Don't assign a layout in case of R1 -> effective R1 reshape.
return nullptr;
}
@ -1026,7 +1025,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) {
return absl::make_unique<Layout>(operand_shape.layout());
}
if (ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)) {
if (operand_shape.rank() == output_shape.rank()) {
*operand_shape.mutable_layout() = output_layout;
if (ShapeUtil::ReshapeIsBitcast(operand_shape,
output_shape_with_layout)) {
@ -1045,7 +1044,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
if (instruction->opcode() == HloOpcode::kTranspose) {
// Pick the operand layout that makes the transpose a bitcast.
int64 rank = ShapeUtil::Rank(instruction->shape());
int64 rank = instruction->shape().rank();
std::vector<int64> new_minor_to_major(rank);
for (int64 i = 0; i < rank; ++i) {
int64 output_dim = LayoutUtil::Minor(output_layout, i);
@ -1070,7 +1069,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
ShapeUtil::IsArray(operand->shape()));
if (!ShapeUtil::IsScalar(operand->shape()) &&
ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape()) &&
operand->shape().rank() == user->shape().rank() &&
!instruction_can_change_layout_func_(user)) {
// Assign users the same layout as the operand.
return absl::make_unique<Layout>(operand_layout);
@ -1083,7 +1082,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
// reshape is a bitcast when using the same layout. This may avoid copy
// operations. For similar reasons, if the operand and output have the same
// rank, try to match the outputs's layout to the operand.
if (ShapeUtil::Rank(operand->shape()) == 1 &&
if (operand->shape().rank() == 1 &&
ShapeUtil::TrueRank(user->shape()) == 1) {
// Don't assign a layout in case of R1 -> effective R1 reshape.
return nullptr;
@ -1098,7 +1097,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) {
return absl::make_unique<Layout>(output_shape.layout());
}
if (ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)) {
if (operand->shape().rank() == output_shape.rank()) {
*output_shape.mutable_layout() = operand_layout;
if (ShapeUtil::ReshapeIsBitcast(output_shape,
operand_shape_with_layout)) {
@ -1117,7 +1116,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
if (user->opcode() == HloOpcode::kTranspose) {
// Pick the user layout that makes the transpose a bitcast.
int64 rank = ShapeUtil::Rank(user->shape());
int64 rank = user->shape().rank();
std::vector<int64> new_minor_to_major(rank);
auto inverse_dimensions = InversePermutation(user->dimensions());
for (int64 i = 0; i < rank; ++i) {
@ -1273,7 +1272,7 @@ Status LayoutAssignment::PropagateOperandConstraint(
return Status::OK();
}
int64 operand_rank = ShapeUtil::Rank(operand->shape());
int64 operand_rank = operand->shape().rank();
if (operand_rank <= 1) {
return Status::OK();
}
@ -1288,7 +1287,7 @@ Status LayoutAssignment::PropagateOperandConstraint(
continue;
}
const HloInstruction* sibling = user->operand(operand_no);
const int64 sibling_rank = ShapeUtil::Rank(sibling->shape());
const int64 sibling_rank = sibling->shape().rank();
if (sibling_rank <= 1) {
continue;
}
@ -1320,13 +1319,13 @@ Status LayoutAssignment::PropagateOperandConstraint(
if (ShapeUtil::IsTuple(subshape)) {
return Status::OK();
}
if (ShapeUtil::Rank(subshape) <= 1) {
if (subshape.rank() <= 1) {
return Status::OK();
}
// Assign the right layout to input fusion of higher rank reduce
// operations.
if (ShapeUtil::Rank(subshape) != ShapeUtil::Rank(operand->shape())) {
if (subshape.rank() != operand->shape().rank()) {
return Status::OK();
}
// TODO(b/67641796): Are there cases except fusion that use this code
@ -1357,7 +1356,7 @@ Status LayoutAssignment::PropagateOperandConstraint(
if (ShapeUtil::IsTuple(subshape)) {
return Status::OK();
}
if (ShapeUtil::Rank(subshape) <= 1) {
if (subshape.rank() <= 1) {
return Status::OK();
}
TF_ASSIGN_OR_RETURN(
@ -1402,7 +1401,7 @@ Status LayoutAssignment::PropagateBufferConstraintToOperands(
if (!instruction_can_change_layout_func_(instruction)) {
// Copy the layout to the operand.
if (buffer.IsArray() && ShapeUtil::IsArray(operand->shape()) &&
ShapeUtil::Rank(operand->shape()) ==
operand->shape().rank() ==
LayoutUtil::MinorToMajor(buffer_constraint.layout()).size()) {
TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
buffer_constraint.layout(), instruction, operand_no,
@ -2101,7 +2100,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
/* static */
bool LayoutAssignment::IsAtMostRank1(const Shape& shape) {
if (ShapeUtil::IsArray(shape)) {
return ShapeUtil::Rank(shape) <= 1;
return shape.rank() <= 1;
}
return absl::c_all_of(shape.tuple_shapes(), [](const Shape& subshape) {
return IsAtMostRank1(subshape);

View File

@ -528,8 +528,7 @@ class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment {
for (int64 operand_no = 0; operand_no < instruction->operand_count();
++operand_no) {
const HloInstruction* operand = instruction->operand(operand_no);
if (ShapeUtil::Rank(instruction->shape()) !=
ShapeUtil::Rank(operand->shape())) {
if (instruction->shape().rank() != operand->shape().rank()) {
continue;
}
TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(

View File

@ -44,7 +44,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl(
const Shape& output_shape = output_array.GetShape();
// Read start indices from start_indices_generator.
const int64 rank = ShapeUtil::Rank(output_shape);
const int64 rank = output_shape.rank();
IrArray::Index start_index(b->getInt64Ty(), rank);
for (int64 i = 0; i < rank; ++i) {
IrArray::Index dim_index({b->getInt64(i)});

View File

@ -61,7 +61,7 @@ void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
llvm::IRBuilder<>* b)
: multidim_(ShapeUtil::Rank(shape)),
: multidim_(shape.rank()),
linear_(linear),
layout_(shape.layout()),
dims_(shape.dimensions().begin(), shape.dimensions().end()) {
@ -120,7 +120,7 @@ IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape)
if (!ShapeUtil::IsArray(*shape_) || ShapeUtil::IsScalar(*shape_)) {
DCHECK(depth == 1 || depth == 0) << depth;
} else {
DCHECK_EQ(depth, ShapeUtil::Rank(*shape_)) << shape.ShortDebugString();
DCHECK_EQ(depth, shape_->rank()) << shape.ShortDebugString();
}
}
@ -137,12 +137,12 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
const Shape& output_shape, const Shape& input_shape,
llvm::IRBuilder<>* builder) const {
const auto& target_index = *this;
CHECK_EQ(target_index.size(), ShapeUtil::Rank(output_shape));
CHECK_EQ(target_index.size(), output_shape.rank());
std::vector<std::pair<int64, int64>> common_factors =
CommonFactors(AsInt64Slice(input_shape.dimensions()),
AsInt64Slice(output_shape.dimensions()));
std::vector<llvm::Value*> source_multidim_index(
ShapeUtil::Rank(input_shape), llvm::UndefValue::get(index_type_));
input_shape.rank(), llvm::UndefValue::get(index_type_));
// We compute the source indices in each common factor from only the target
// indices in the same common factor.
for (ssize_t k = common_factors.size() - 2; k >= 0; --k) {
@ -257,7 +257,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
const Shape& shape, const Shape& operand_shape,
absl::Span<const int64> dimension_mapping,
llvm::IRBuilder<>* builder) const {
int64 rank = ShapeUtil::Rank(operand_shape);
int64 rank = operand_shape.rank();
std::vector<llvm::Value*> source_index(rank);
for (int64 i = 0; i < rank; ++i) {
source_index[i] = multidim_[dimension_mapping[i]];
@ -271,7 +271,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
// The other dimensions can be masked out with a div and a mod operation.
std::vector<int64> logical_to_physical =
LayoutUtil::MakeLogicalToPhysical(shape.layout());
int64 output_rank = ShapeUtil::Rank(shape);
int64 output_rank = shape.rank();
// The minimum physical dimension that is broadcasted.
int64 min_broadcasted_dimension = output_rank;
// The maximum physical dimension that is broadcasted.
@ -348,7 +348,7 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
// over higher-rank arrays.
return base_ptr_;
}
CHECK_EQ(index.size(), ShapeUtil::Rank(*shape_));
CHECK_EQ(index.size(), shape_->rank());
if (index.LinearValidOnShape(*shape_)) {
llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();

View File

@ -235,7 +235,7 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
absl::string_view suffix) {
std::vector<int64> dimensions(ShapeUtil::Rank(shape));
std::vector<int64> dimensions(shape.rank());
std::iota(dimensions.begin(), dimensions.end(), 0);
return AddLoopsForShapeOnDimensions(shape, dimensions, suffix);
}

View File

@ -322,7 +322,7 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
// comparisons).
const Shape& keys_shape = keys_array.GetShape();
int64 rank = ShapeUtil::Rank(keys_shape);
int64 rank = keys_shape.rank();
int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
std::vector<int64> dimensions_in_iteration_order(rank);
std::vector<int64> iteration_order_to_logical_order(rank);

View File

@ -831,7 +831,7 @@ class ShapePatternRankImpl {
explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {}
bool Match(const ::xla::Shape* shape, MatchOption option) const {
if (ShapeUtil::Rank(*shape) != rank_) {
if (shape->rank() != rank_) {
if (rank_ == 0) {
EXPLAIN << "Shape is not a scalar";
} else {

View File

@ -88,7 +88,7 @@ static StatusOr<HloInstruction*> CanonicalizeScatterIndices(
static StatusOr<HloInstruction*> PermuteScatterAndWindowDims(
HloInstruction* updates, absl::Span<const int64> update_window_dims) {
std::vector<int64> permutation;
const int64 updates_rank = ShapeUtil::Rank(updates->shape());
const int64 updates_rank = updates->shape().rank();
permutation.reserve(updates_rank);
for (int64 i = 0; i < updates_rank; ++i) {

View File

@ -96,7 +96,7 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape,
}
for (const Shape* element_shape : accumulator_subshapes) {
if (ShapeUtil::Rank(*element_shape) != 0) {
if (element_shape->rank() != 0) {
return InvalidArgument(
"Reduction function must return a scalar or tuple of scalars but "
"returns shape: %s",
@ -160,10 +160,10 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
const Window& window,
PrimitiveType element_type,
bool allow_negative_padding) {
if (window.dimensions_size() != ShapeUtil::Rank(base_shape)) {
if (window.dimensions_size() != base_shape.rank()) {
return InvalidArgument(
"Window has dimension %d but base shape has dimension %d.",
window.dimensions_size(), ShapeUtil::Rank(base_shape));
window.dimensions_size(), base_shape.rank());
}
std::vector<int64> output_dimensions(window.dimensions_size());
@ -338,7 +338,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
if (arg_shapes.empty()) {
return InvalidArgument("Concatenate expects at least one argument.");
}
if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) {
if (dimension < 0 || dimension >= arg_shapes[0]->rank()) {
return InvalidArgument("Concatenate dimension out of bounds: %d.",
dimension);
}
@ -351,12 +351,12 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
element_type = arg_shape->element_type();
continue;
}
if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) {
if (arg_shape->rank() != shape->rank()) {
return InvalidArgument(
"Cannot concatenate arrays with different ranks: %d (%s) vs %d "
"(%s).",
ShapeUtil::Rank(*arg_shape), ShapeUtil::HumanString(*arg_shape),
ShapeUtil::Rank(*shape), ShapeUtil::HumanString(*shape));
arg_shape->rank(), ShapeUtil::HumanString(*arg_shape), shape->rank(),
ShapeUtil::HumanString(*shape));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) {
return InvalidArgument(
@ -364,8 +364,8 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
PrimitiveType_Name(arg_shape->element_type()),
PrimitiveType_Name(shape->element_type()));
}
for (int64 dimension_number = 0;
dimension_number < ShapeUtil::Rank(*arg_shape); ++dimension_number) {
for (int64 dimension_number = 0; dimension_number < arg_shape->rank();
++dimension_number) {
if (arg_shape->dimensions(dimension_number) !=
shape->dimensions(dimension_number)) {
if (dimension_number == dimension) {
@ -480,7 +480,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return InvalidArgument(
"Pad operation does not support non-scalar padding values.");
}
if (ShapeUtil::Rank(operand_shape) != padding_config.dimensions_size()) {
if (operand_shape.rank() != padding_config.dimensions_size()) {
return InvalidArgument(
"The rank of the operand and the padding configuration do not match: "
"%s vs %s.",
@ -500,7 +500,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
padding_config.ShortDebugString());
}
std::vector<int64> dimensions(ShapeUtil::Rank(operand_shape));
std::vector<int64> dimensions(operand_shape.rank());
for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
const auto& p = padding_config.dimensions(i);
dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() +
@ -555,9 +555,9 @@ Status ValidateDotDimensionNumbers(
absl::Span<const int64> rhs_batch_dimensions =
AsInt64Slice(dimension_numbers.rhs_batch_dimensions());
if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions,
if (!dims_in_range(lhs.rank(), lhs_contracting_dimensions,
lhs_batch_dimensions) ||
!dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions,
!dims_in_range(rhs.rank(), rhs_contracting_dimensions,
rhs_batch_dimensions)) {
return InvalidArgument("A dimension number is out of range in Dot: %s.",
dimension_numbers.DebugString());
@ -583,12 +583,10 @@ Status ValidateDotDimensionNumbers(
// Check that the count of non-contracting-non-batch dimensions is in {0, 1}.
const int64 lhs_non_contracting_non_batch_dims =
ShapeUtil::Rank(lhs) -
dimension_numbers.lhs_contracting_dimensions_size() -
lhs.rank() - dimension_numbers.lhs_contracting_dimensions_size() -
dimension_numbers.lhs_batch_dimensions_size();
const int64 rhs_non_contracting_non_batch_dims =
ShapeUtil::Rank(rhs) -
dimension_numbers.rhs_contracting_dimensions_size() -
rhs.rank() - dimension_numbers.rhs_contracting_dimensions_size() -
dimension_numbers.rhs_batch_dimensions_size();
if (lhs_non_contracting_non_batch_dims < 0 ||
lhs_non_contracting_non_batch_dims > 1 ||
@ -637,7 +635,7 @@ Status ValidateDotDimensionNumbers(
return fail("Element types do not match.");
}
if ((ShapeUtil::Rank(lhs) < 1) || (ShapeUtil::Rank(rhs) < 1)) {
if ((lhs.rank() < 1) || (rhs.rank() < 1)) {
return fail("Dot only supports rank 1 or above.");
}
@ -686,12 +684,12 @@ Status ValidateDotDimensionNumbers(
std::unordered_set<int64> rhs_batch_dims(
dimension_numbers.rhs_batch_dimensions().begin(),
dimension_numbers.rhs_batch_dimensions().end());
for (int64 i = 0; i < ShapeUtil::Rank(lhs); i++) {
for (int64 i = 0; i < lhs.rank(); i++) {
if (i != lhs_contracting_dimension) {
dimensions.push_back(lhs.dimensions(i));
}
}
for (int64 i = 0; i < ShapeUtil::Rank(rhs); i++) {
for (int64 i = 0; i < rhs.rank(); i++) {
if (i != rhs_contracting_dimension && rhs_batch_dims.count(i) == 0) {
dimensions.push_back(rhs.dimensions(i));
}
@ -708,14 +706,14 @@ Status ValidateDotDimensionNumbers(
ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Shape& lhs,
const Shape& rhs) {
TF_RET_CHECK(ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs));
TF_RET_CHECK(lhs.rank() == rhs.rank());
// The shapes have to be compatible. That is, if some dimension d has a
// different size in the two shapes, one of them has to be 1 (a "degenerate"
// dimension). In that case, the output shape has the non-1 dimension size
// from the lhs/rhs pair in every index.
std::vector<int64> output_dimensions(ShapeUtil::Rank(lhs));
for (int64 i = 0; i < ShapeUtil::Rank(lhs); ++i) {
std::vector<int64> output_dimensions(lhs.rank());
for (int64 i = 0; i < lhs.rank(); ++i) {
if (lhs.dimensions(i) == rhs.dimensions(i)) {
output_dimensions[i] = lhs.dimensions(i);
} else if (lhs.dimensions(i) == 1) {
@ -743,13 +741,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument("Automatic shape inference not supported: %s and %s",
ShapeUtil::HumanString(smaller_shape),
ShapeUtil::HumanString(larger_shape));
} else if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) {
} else if (broadcast_dimensions.size() != smaller_shape.rank()) {
return InvalidArgument(
"Size of broadcast_dimensions has to match lower-rank operand's "
"rank; "
" lower-rank operand's rank is %d, size of broadcast_dimensions is "
"%u.",
ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size());
smaller_shape.rank(), broadcast_dimensions.size());
}
// broadcast_dimensions is a sequence of dimensions; its length is equal to
@ -847,8 +845,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
ShapeUtil::HumanString(rhs));
}
if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) {
std::vector<int64> identity_dims(ShapeUtil::Rank(lhs));
if (lhs.rank() == rhs.rank()) {
std::vector<int64> identity_dims(lhs.rank());
std::iota(identity_dims.begin(), identity_dims.end(), 0);
if (!broadcast_dimensions.empty() &&
broadcast_dimensions != identity_dims) {
@ -865,15 +863,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs));
}
if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) {
if (lhs.rank() == rhs.rank()) {
return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
} else {
// Ranks do not match, so perform InDim broadcasting using
// broadcast_dimensions. Scalar broadcasting is a special case of this.
const Shape& larger_shape =
ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? lhs : rhs;
const Shape& smaller_shape =
ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? rhs : lhs;
const Shape& larger_shape = lhs.rank() > rhs.rank() ? lhs : rhs;
const Shape& smaller_shape = lhs.rank() > rhs.rank() ? rhs : lhs;
// After InDim broadcasting, perform degenerate dimensions broadcasting.
TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape,
@ -1162,12 +1158,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) ==
Status::OK());
if (feature_index >= ShapeUtil::Rank(operand_shape)) {
if (feature_index >= operand_shape.rank()) {
return InvalidArgument(
"Expected feature_index of batch-norm-training to be "
"smaller than the rank of operand_shape; "
"got feature_index %d, and rank %d.",
feature_index, ShapeUtil::Rank(operand_shape));
feature_index, operand_shape.rank());
}
if (feature_index < 0) {
@ -1177,25 +1173,25 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
feature_index);
}
if (ShapeUtil::Rank(operand_shape) < 1) {
if (operand_shape.rank() < 1) {
return InvalidArgument(
"Expected the rank of operand to "
"batch-norm-training to be at least 1; got %d.",
ShapeUtil::Rank(operand_shape));
operand_shape.rank());
}
if (ShapeUtil::Rank(offset_shape) != 1) {
if (offset_shape.rank() != 1) {
return InvalidArgument(
"Offset input of batch-norm-training must have"
" rank 1, but has rank %d.",
ShapeUtil::Rank(offset_shape));
offset_shape.rank());
}
if (ShapeUtil::Rank(scale_shape) != 1) {
if (scale_shape.rank() != 1) {
return InvalidArgument(
"Scale input of batch-norm-training must have"
" rank 1, but has rank %d.",
ShapeUtil::Rank(scale_shape));
scale_shape.rank());
}
if (!ShapeUtil::ElementIsFloating(operand_shape)) {
@ -1272,12 +1268,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape) ==
Status::OK());
if (feature_index >= ShapeUtil::Rank(operand_shape)) {
if (feature_index >= operand_shape.rank()) {
return InvalidArgument(
"Expected feature_index of batch-norm-inference to be "
"smaller than the rank of operand_shape; "
"got feature_index %d, and rank %d.",
feature_index, ShapeUtil::Rank(operand_shape));
feature_index, operand_shape.rank());
}
if (feature_index < 0) {
@ -1287,25 +1283,25 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
feature_index);
}
if (ShapeUtil::Rank(operand_shape) < 1) {
if (operand_shape.rank() < 1) {
return InvalidArgument(
"Expected the rank of operand to "
"batch-norm-inference to be at least 1; got %d.",
ShapeUtil::Rank(operand_shape));
operand_shape.rank());
}
if (ShapeUtil::Rank(offset_shape) != 1) {
if (offset_shape.rank() != 1) {
return InvalidArgument(
"Offset input of batch-norm-inference must have"
" rank 1, but has rank %d.",
ShapeUtil::Rank(offset_shape));
offset_shape.rank());
}
if (ShapeUtil::Rank(scale_shape) != 1) {
if (scale_shape.rank() != 1) {
return InvalidArgument(
"Scale input of batch-norm-inference must have"
" rank 1, but has rank %d.",
ShapeUtil::Rank(scale_shape));
scale_shape.rank());
}
if (!ShapeUtil::ElementIsFloating(operand_shape)) {
@ -1417,41 +1413,41 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
TF_RETURN_IF_ERROR(
ShapeUtil::ValidateShapeWithOptionalLayout(output_grad_shape));
if (feature_index >= ShapeUtil::Rank(operand_shape)) {
if (feature_index >= operand_shape.rank()) {
return InvalidArgument(
"Expected feature_index of batch-norm-grad to be "
"smaller than the rank of operand_shape; "
"got feature_index %d, and rank %d.",
feature_index, ShapeUtil::Rank(operand_shape));
feature_index, operand_shape.rank());
}
if (ShapeUtil::Rank(operand_shape) != ShapeUtil::Rank(output_grad_shape)) {
if (operand_shape.rank() != output_grad_shape.rank()) {
return InvalidArgument(
"Expected operand_shape of batch-norm-grad to have the same rank as"
" output_grad_shape; got rank(oprand_shape) %d, and"
" rank(output_grad_shape) %d.",
ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(output_grad_shape));
operand_shape.rank(), output_grad_shape.rank());
}
if (ShapeUtil::Rank(mean_shape) != 1) {
if (mean_shape.rank() != 1) {
return InvalidArgument(
"Mean input of batch-norm-grad must have"
" rank 1, but has rank %d.",
ShapeUtil::Rank(mean_shape));
mean_shape.rank());
}
if (ShapeUtil::Rank(scale_shape) != 1) {
if (scale_shape.rank() != 1) {
return InvalidArgument(
"Scale input of batch-norm-grad must have"
" rank 1, but has rank %d.",
ShapeUtil::Rank(scale_shape));
scale_shape.rank());
}
if (ShapeUtil::Rank(var_shape) != 1) {
if (var_shape.rank() != 1) {
return InvalidArgument(
"Var input of batch-norm-grad must have"
" rank 1, but has rank %d.",
ShapeUtil::Rank(var_shape));
var_shape.rank());
}
if (!ShapeUtil::ElementIsFloating(operand_shape)) {
@ -1538,7 +1534,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
// Verify operand_shape and output_grad_shape have same bounds.
for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
for (int64 i = 0; i < operand_shape.rank(); ++i) {
if (ShapeUtil::GetDimension(operand_shape, i) !=
ShapeUtil::GetDimension(output_grad_shape, i)) {
return InvalidArgument(
@ -1603,12 +1599,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
const int num_dims = num_spatial_dims + 2;
if (ShapeUtil::Rank(lhs) != num_dims) {
if (lhs.rank() != num_dims) {
return InvalidArgument(
"The LHS argument to a convolution should have rank %d; lhs: %s.",
num_dims, ShapeUtil::HumanString(lhs));
}
if (ShapeUtil::Rank(rhs) != num_dims) {
if (rhs.rank() != num_dims) {
return InvalidArgument(
"The RHS argument to a convolution should have rank %d; rhs: %s.",
num_dims, ShapeUtil::HumanString(rhs));
@ -1853,12 +1849,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Shape& shape, int64 split_dimension, int64 concat_dimension,
int64 split_count) {
TF_RET_CHECK(split_count > 0);
if (split_dimension >= ShapeUtil::Rank(shape) || split_dimension < 0) {
if (split_dimension >= shape.rank() || split_dimension < 0) {
return InvalidArgument(
"AllToAll split_dimension %d is out-of-bounds in shape %s.",
split_dimension, ShapeUtil::HumanString(shape));
}
if (concat_dimension >= ShapeUtil::Rank(shape) || concat_dimension < 0) {
if (concat_dimension >= shape.rank() || concat_dimension < 0) {
return InvalidArgument(
"AllToAll concat_dimension %d is out-of-bounds in shape %s.",
concat_dimension, ShapeUtil::HumanString(shape));
@ -1932,7 +1928,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
// doesn't matter which one we choose.
const Shape& arg = *reduced_args[0];
for (int64 dimension : dimensions_to_reduce) {
if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) {
if (dimension >= arg.rank() || dimension < 0) {
return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.",
dimension, ShapeUtil::HumanString(arg));
}
@ -1949,7 +1945,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
std::set<int64> dimensions_to_reduce_set(dimensions_to_reduce.begin(),
dimensions_to_reduce.end());
std::vector<int64> new_dimensions;
for (int i = 0; i < ShapeUtil::Rank(arg); ++i) {
for (int i = 0; i < arg.rank(); ++i) {
if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) {
new_dimensions.push_back(arg.dimensions(i));
}
@ -2041,7 +2037,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferGetDimensionSizeShape(
const Shape& shape, int64 dimension) {
if (dimension < 0 || dimension >= ShapeUtil::Rank(shape)) {
if (dimension < 0 || dimension >= shape.rank()) {
return InvalidArgument("GetDimensionSize dimension out of bounds: %d.",
dimension);
}
@ -2083,10 +2079,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
starts.size(), strides.size()));
}
if (starts.size() != ShapeUtil::Rank(arg)) {
if (starts.size() != arg.rank()) {
return InvalidArgument(
"Slice index count does not match argument rank: %u vs %d.",
starts.size(), ShapeUtil::Rank(arg));
starts.size(), arg.rank());
}
std::vector<int64> sizes;
@ -2132,10 +2128,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
ShapeUtil::HumanString(operand_shape),
ShapeUtil::HumanString(start_indices_shape), StrJoin(slice_sizes, ", "));
if (ShapeUtil::Rank(start_indices_shape) != 1) {
if (start_indices_shape.rank() != 1) {
return InvalidArgument(
"Dynamic slice start indices of rank %d must be rank1.",
ShapeUtil::Rank(start_indices_shape));
start_indices_shape.rank());
}
if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
@ -2144,18 +2140,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
const int64 start_num_dims = start_indices_shape.dimensions(0);
if (ShapeUtil::Rank(operand_shape) != start_num_dims) {
if (operand_shape.rank() != start_num_dims) {
return InvalidArgument(
"Dynamic slice start number of dimensions %d (%s) must match rank "
"%d of slice input (%s).",
start_num_dims, ShapeUtil::HumanString(start_indices_shape),
ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape));
operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
}
if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) {
if (slice_sizes.size() != operand_shape.rank()) {
return InvalidArgument(
"Dynamic slice index count does not match argument rank: %u vs %d.",
slice_sizes.size(), ShapeUtil::Rank(operand_shape));
slice_sizes.size(), operand_shape.rank());
}
for (int64 dim = 0; dim < slice_sizes.size(); ++dim) {
@ -2193,10 +2189,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
ShapeUtil::HumanString(start_indices_shape),
ShapeUtil::HumanString(update_shape));
if (ShapeUtil::Rank(start_indices_shape) != 1) {
if (start_indices_shape.rank() != 1) {
return InvalidArgument(
"Dynamic update slice start indices of rank %d must be rank1.",
ShapeUtil::Rank(start_indices_shape));
start_indices_shape.rank());
}
if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
@ -2205,19 +2201,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
const int64 start_num_dims = start_indices_shape.dimensions(0);
if (ShapeUtil::Rank(operand_shape) != start_num_dims) {
if (operand_shape.rank() != start_num_dims) {
return InvalidArgument(
"Dynamic update slice start number of dimensions %d (%s) must match "
"rank %d of slice input (%s).",
start_num_dims, ShapeUtil::HumanString(start_indices_shape),
ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape));
operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
}
if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) {
if (update_shape.rank() != operand_shape.rank()) {
return InvalidArgument(
"Dynamic update slice update rank does not match argument rank: "
"%d vs %d.",
ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape));
update_shape.rank(), operand_shape.rank());
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
@ -2229,7 +2225,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
PrimitiveType_Name(update_shape.element_type()));
}
for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) {
for (int64 dim = 0; dim < operand_shape.rank(); ++dim) {
const int64 input_dim_size = operand_shape.dimensions(dim);
const int64 update_dim_size = update_shape.dimensions(dim);
if (update_dim_size < 0) {
@ -2255,7 +2251,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument("a dimension number is duplicated in reverse");
}
for (int64 dimension : dimensions) {
if (dimension >= ShapeUtil::Rank(operand_shape) || dimension < 0) {
if (dimension >= operand_shape.rank() || dimension < 0) {
return InvalidArgument(
"One of the reverse dimensions (%d) is out-of-bounds in shape %s.",
dimension, ShapeUtil::HumanString(operand_shape));
@ -2397,8 +2393,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
absl::Span<const int64> broadcast_dimensions) {
TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of broadcast"));
TF_RETURN_IF_ERROR(ExpectArray(output_shape, "operand of broadcast"));
const int64 operand_rank = ShapeUtil::Rank(operand_shape);
const int64 output_rank = ShapeUtil::Rank(output_shape);
const int64 operand_rank = operand_shape.rank();
const int64 output_rank = output_shape.rank();
if (operand_rank > output_rank) {
return InvalidArgument(
"InDim style broadcast must be to an equal or higher ranked shape; "
@ -2457,9 +2453,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
ShapeUtil::HumanString(inferred_shape));
}
std::vector<int64> indices(ShapeUtil::Rank(operand));
std::vector<int64> indices(operand.rank());
std::iota(indices.begin(), indices.end(), 0);
if (dimensions.size() != ShapeUtil::Rank(operand) ||
if (dimensions.size() != operand.rank() ||
!std::is_permutation(dimensions.begin(), dimensions.end(),
indices.begin())) {
return InvalidArgument(
@ -2475,9 +2471,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Shape& operand, absl::Span<const int64> dimensions) {
TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
std::vector<int64> indices(ShapeUtil::Rank(operand));
std::vector<int64> indices(operand.rank());
std::iota(indices.begin(), indices.end(), 0);
if (dimensions.size() != ShapeUtil::Rank(operand) ||
if (dimensions.size() != operand.rank() ||
!std::is_permutation(dimensions.begin(), dimensions.end(),
indices.begin())) {
return InvalidArgument(
@ -2829,7 +2825,7 @@ Status ValidateScatterDimensionNumbers(
"update_window_dims in scatter op must not repeat; got: %s.",
StrJoin(dim_numbers.update_window_dims(), ", "));
}
const int64 updates_rank = ShapeUtil::Rank(updates_shape);
const int64 updates_rank = updates_shape.rank();
for (int64 window_dim : dim_numbers.update_window_dims()) {
if (window_dim < 0 || window_dim >= updates_rank) {
return InvalidArgument(
@ -2863,10 +2859,10 @@ Status ValidateScatterDimensionNumbers(
// Validate window size.
auto window_size = dim_numbers.update_window_dims_size() +
dim_numbers.inserted_window_dims_size();
if (window_size != ShapeUtil::Rank(operand_shape)) {
if (window_size != operand_shape.rank()) {
return InvalidArgument(
"Scatter op has window of size %d; doesn't match operand of rank %d.",
window_size, ShapeUtil::Rank(operand_shape));
window_size, operand_shape.rank());
}
// Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers.
@ -2951,10 +2947,9 @@ Status ValidateScatterDimensionNumbers(
int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 +
scatter_dim_numbers.update_window_dims_size();
if (ShapeUtil::Rank(updates_shape) != expected_updates_rank) {
if (updates_shape.rank() != expected_updates_rank) {
return InvalidArgument("Updates tensor must be of rank %d; got %d.",
expected_updates_rank,
ShapeUtil::Rank(updates_shape));
expected_updates_rank, updates_shape.rank());
}
TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers(
@ -2985,7 +2980,7 @@ Status ValidateScatterDimensionNumbers(
}
int64 scatter_dims_seen = 0;
for (int64 i = 0; i < ShapeUtil::Rank(updates_shape); ++i) {
for (int64 i = 0; i < updates_shape.rank(); ++i) {
bool is_update_window_dim =
absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i);
if (is_update_window_dim) {

View File

@ -45,7 +45,7 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot(
auto& operand = *dot.operand(i);
if (operand.IsRank2Transpose()) {
operand_set.push_back(i);
} else if (ShapeUtil::Rank(operand.shape()) != 2) {
} else if (operand.shape().rank() != 2) {
return {};
}
}

View File

@ -61,6 +61,12 @@ string Shape::ToString(bool print_layout) const {
}
}
int64 Shape::rank() const {
CHECK(ShapeUtil::IsArray(*this))
<< "Non-arrays do not have a rank, shape: " << *this;
return dimensions_.size();
}
std::ostream& operator<<(std::ostream& out, const Shape& shape) {
out << shape.ToString(/*print_layout=*/true);
return out;

View File

@ -44,6 +44,10 @@ class Shape {
// without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]".
string ToString(bool print_layout = false) const;
// Returns the rank (number of dimensions) of the given shape. Shape must be
// an array.
int64 rank() const;
// The following methods mirror the protobuf generated code interface for the
// message ShapeProto. This enabled easy migration of this data structure
// from a proto to a proper C++ class.

View File

@ -679,8 +679,8 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
/* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) {
TF_DCHECK_OK(ValidateShape(shape));
CHECK(LayoutUtil::IsSparseArray(shape));
return LayoutUtil::MaxSparseElements(shape.layout()) *
ShapeUtil::Rank(shape) * sizeof(int64);
return LayoutUtil::MaxSparseElements(shape.layout()) * shape.rank() *
sizeof(int64);
}
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
@ -763,7 +763,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return sparse_elements_size;
}
int64 sparse_indices_size =
MultiplyWithoutOverflow(max_sparse_elements, ShapeUtil::Rank(shape));
MultiplyWithoutOverflow(max_sparse_elements, shape.rank());
if (sparse_indices_size < 0) {
return sparse_indices_size;
}

View File

@ -298,6 +298,7 @@ class ShapeUtil {
// Returns the rank (number of dimensions) of the given shape.
// Precondition: !IsTuple(shape)
ABSL_DEPRECATED("Use `Shape::rank` instead.")
static int64 Rank(const Shape& shape);
// Returns the number of dimensions for which the dimension is not (trivially)

View File

@ -79,7 +79,7 @@ void SparseIndexArray::Resize(int64 num_indices) {
}
bool SparseIndexArray::Validate(const Shape& shape) const {
if (rank_ == 0 || rank_ != ShapeUtil::Rank(shape)) {
if (rank_ == 0 || rank_ != shape.rank()) {
return false;
}
int64 num_indices = index_count();

View File

@ -191,7 +191,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
verify_output(actual, "");
// Try with all output layouts.
std::vector<int64> minor_to_major(ShapeUtil::Rank(expected.shape()));
std::vector<int64> minor_to_major(expected.shape().rank());
std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
do {
auto layout = ShapeUtil::MakeShapeWithLayout(
@ -234,7 +234,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
return Status::OK();
}
std::vector<int64> minor_to_major(ShapeUtil::Rank(literal.shape()));
std::vector<int64> minor_to_major(literal.shape().rank());
std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
do {
auto literal_relayout =

View File

@ -348,7 +348,7 @@ StatusOr<Literal> CreateLiteralForConstrainedUses(
const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice
? use->shape()
: use->operand(1)->shape();
const int64 rank = ShapeUtil::Rank(indexed_shape);
const int64 rank = indexed_shape.rank();
if (!index_space.empty()) {
TF_RET_CHECK(rank == index_space.size());
for (int64 i = 0; i < rank; ++i) {
@ -459,8 +459,8 @@ Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
HloInstruction* lhs,
HloInstruction* rhs) {
CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
CHECK_EQ(lhs->shape().rank(), 2);
CHECK_EQ(rhs->shape().rank(), 2);
PrecisionConfig precision_config;
precision_config.mutable_operand_precision()->Resize(
2, PrecisionConfig::DEFAULT);