Replace calls to ShapeUtil::Rank with Shape::rank.
No functional change. ShapeUtil::Rank is marked as deprecated. A later CL will remove it.

PiperOrigin-RevId: 226412117
commit 1ed59e52b1 (parent e7bac9435d)
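The migration pattern is mechanical throughout the hunks below: the deprecated static helper is replaced by the equivalent member accessor. A minimal illustrative sketch (not part of the commit; the Example function is hypothetical):

    // Hypothetical example, assuming the XLA Shape API as used in the
    // hunks below; only the call syntax changes, not the computed value.
    void Example(const xla::Shape& shape) {
      int64 old_rank = xla::ShapeUtil::Rank(shape);  // Deprecated static helper.
      int64 new_rank = shape.rank();                 // Replacement accessor.
      CHECK_EQ(old_rank, new_rank);                  // No functional change.
    }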
@@ -129,7 +129,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
   TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type));
   std::vector<string> dim_vars;
   string dim_sizes, indices;
-  if (xla::ShapeUtil::Rank(shape) == 0 ||
+  if (shape.rank() == 0 ||
       (shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) {
     dim_sizes = "[1]";
     indices = "[0]";
@@ -39,7 +39,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
 
   OP_REQUIRES(
       ctx,
-      xla::ShapeUtil::Rank(crops.shape()) == 2 &&
+      crops.shape().rank() == 2 &&
           block_rank == xla::ShapeUtil::GetDimension(crops.shape(), 0) &&
           2 == xla::ShapeUtil::GetDimension(crops.shape(), 1),
       errors::InvalidArgument("crops should have shape [", block_rank,
@@ -110,8 +110,8 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
       auto shape_status = b.GetShape(arg);
       OP_REQUIRES_OK(ctx, shape_status.status());
       xla::Shape arg_shape = shape_status.ConsumeValueOrDie();
-      *arg_shape.mutable_layout() = xla::LayoutUtil::MakeDescendingLayout(
-          xla::ShapeUtil::Rank(arg_shape));
+      *arg_shape.mutable_layout() =
+          xla::LayoutUtil::MakeDescendingLayout(arg_shape.rank());
       arg_shapes.push_back(std::move(arg_shape));
     }
 
@@ -38,8 +38,7 @@ class MirrorPadOp : public XlaOpKernel {
     // - [1, 2, 3, 3, 2] in symmetric mode.
     int64 excluded_edges = mode == MirrorPadMode::REFLECT ? 1 : 0;
     xla::XlaOp accum = t;
-    for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0;
-         --dimno) {
+    for (int64 dimno = original_shape.rank() - 1; dimno >= 0; --dimno) {
       auto t_rev = xla::Rev(accum, {dimno});
       int64 lhs_padding = pad_literal.Get<int64>({dimno, 0});
       int64 rhs_padding = pad_literal.Get<int64>({dimno, 1});
@@ -39,7 +39,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
 
   OP_REQUIRES(
       ctx,
-      xla::ShapeUtil::Rank(paddings.shape()) == 2 &&
+      paddings.shape().rank() == 2 &&
          block_rank == xla::ShapeUtil::GetDimension(paddings.shape(), 0) &&
          2 == xla::ShapeUtil::GetDimension(paddings.shape(), 1),
      errors::InvalidArgument("paddings should have shape [", block_rank,
@@ -48,7 +48,7 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
   if (indices_are_vectors) {
     TF_RET_CHECK(!indices_dims.empty());
     num_index_dims = indices_dims.back();
-    if (num_index_dims > xla::ShapeUtil::Rank(buffer_shape)) {
+    if (num_index_dims > buffer_shape.rank()) {
       return errors::InvalidArgument(
           "The size of the minor dimension of the indices (shape: ",
           xla::ShapeUtil::HumanString(indices_shape),
@@ -140,8 +140,8 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
               ? indices_shape.dimensions_size() - 1
               : indices_shape.dimensions_size());
 
-  int64 updates_rank = xla::ShapeUtil::Rank(updates_shape);
-  int64 buffer_rank = xla::ShapeUtil::Rank(buffer_shape);
+  int64 updates_rank = updates_shape.rank();
+  int64 buffer_rank = buffer_shape.rank();
   int64 num_window_dims_in_updates = buffer_rank - num_index_dims;
 
   // If the rank of `updates` is 0 and does not match the expected rank of
@@ -156,7 +156,7 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
   if (updates_rank == 0 && expected_updates_rank != 0) {
     new_updates = xla::Broadcast(updates, expected_updates_dims);
     TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates));
-    updates_rank = xla::ShapeUtil::Rank(updates_shape);
+    updates_rank = updates_shape.rank();
   }
 
   if (updates_rank > 0) {
@@ -39,7 +39,7 @@ Status PopulateInfeedLayoutVector(const xla::Shape& shape,
       layouts->push_back(dim);
     }
   } else {
-    layouts->insert(layouts->end(), xla::ShapeUtil::Rank(shape), -1);
+    layouts->insert(layouts->end(), shape.rank(), -1);
   }
   return Status::OK();
 }
@@ -55,7 +55,7 @@ Status XLAShapeToTensorShape(const xla::Shape& shape,
         " cannot be converted to a TensorShape");
   }
   *tensor_shape = TensorShape();
-  for (int i = 0; i < xla::ShapeUtil::Rank(shape); ++i) {
+  for (int i = 0; i < shape.rank(); ++i) {
     tensor_shape->AddDim(shape.dimensions(i));
   }
   return Status::OK();
@@ -55,7 +55,7 @@ xla::XlaOp ArgMinMax(xla::XlaOp input, xla::PrimitiveType output_type, int axis,
 
   xla::XlaOp input_max = xla::Reduce(input, init_value, reducer,
                                      /*dimensions_to_reduce=*/{axis});
-  std::vector<int64> broadcast_dims(xla::ShapeUtil::Rank(input_shape) - 1);
+  std::vector<int64> broadcast_dims(input_shape.rank() - 1);
   std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
   std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
   // Compute a mask that has 1s for elements equal to the maximum.
@@ -178,7 +178,7 @@ Status XlaOpKernelContext::ConstantInputReshaped(
 // Converts an int32 or int64 scalar literal to an int64.
 static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal,
                                    int64* out) {
-  if (xla::ShapeUtil::Rank(literal.shape()) != 0) {
+  if (literal.shape().rank() != 0) {
     return errors::InvalidArgument("value is not a scalar");
   }
   if (literal.shape().element_type() == xla::S32) {
@@ -194,7 +194,7 @@ static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal,
 // Converts an float32 or float64 scalar literal to a float64.
 static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal,
                                      double* out) {
-  if (xla::ShapeUtil::Rank(literal.shape()) != 0) {
+  if (literal.shape().rank() != 0) {
     return errors::InvalidArgument("value is not a scalar");
   }
   if (literal.shape().element_type() == xla::F32) {
@@ -228,7 +228,7 @@ Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) {
 // Converts an int32 or int64 1D literal to an int64 vector.
 static Status LiteralToInt64Vector(const xla::LiteralSlice& literal,
                                    std::vector<int64>* out) {
-  if (xla::ShapeUtil::Rank(literal.shape()) != 1) {
+  if (literal.shape().rank() != 1) {
     return errors::InvalidArgument("value is not 1D");
   }
   int64 size = xla::ShapeUtil::ElementsIn(literal.shape());
@@ -117,7 +117,7 @@ XlaOp Any(XlaOp predicates) {
     XlaComputation logical_or = CreateScalarOrComputation(PRED, builder);
     TF_ASSIGN_OR_RETURN(const Shape& predicates_shape,
                         builder->GetShape(predicates));
-    std::vector<int64> all_dimensions(ShapeUtil::Rank(predicates_shape));
+    std::vector<int64> all_dimensions(predicates_shape.rank());
     std::iota(all_dimensions.begin(), all_dimensions.end(), 0);
     return Reduce(predicates, f, logical_or, all_dimensions);
   });
@@ -54,7 +54,7 @@ XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) {
   XlaBuilder* builder = a.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
-    const int n_dims = ShapeUtil::Rank(a_shape);
+    const int n_dims = a_shape.rank();
     const int64 n = ShapeUtil::GetDimension(a_shape, -1);
     auto major_dims = AsInt64Slice(a_shape.dimensions())
                           .subspan(
@@ -144,7 +144,7 @@ XlaOp Cholesky(XlaOp a, int64 block_size,
   XlaBuilder* builder = a.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
-    const int ndims = ShapeUtil::Rank(a_shape);
+    const int ndims = a_shape.rank();
     if (ndims < 2) {
       return InvalidArgument(
           "Argument to Cholesky must have rank >= 2; shape was %s",
@@ -41,7 +41,7 @@ XlaOp GetMatrixDiagonal(XlaOp x) {
   XlaBuilder* builder = x.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
-    const int64 n_dims = ShapeUtil::Rank(shape);
+    const int64 n_dims = shape.rank();
     TF_RET_CHECK(n_dims >= 2);
     const int64 m = shape.dimensions(n_dims - 2);
     const int64 n = shape.dimensions(n_dims - 1);
@@ -68,7 +68,7 @@ XlaOp TriangleMask(XlaOp x, int diagonal) {
   XlaBuilder* builder = x.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
-    const int64 n_dims = ShapeUtil::Rank(shape);
+    const int64 n_dims = shape.rank();
     TF_RET_CHECK(n_dims >= 2);
     const int64 m = shape.dimensions(n_dims - 2);
     const int64 n = shape.dimensions(n_dims - 1);
@@ -99,12 +99,12 @@ XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) {
 
     // Check that both tensors have the same number of dimensions. There must be
     // at least two (the batch dimensions can be empty).
-    if (ShapeUtil::Rank(x_shape) != ShapeUtil::Rank(y_shape)) {
+    if (x_shape.rank() != y_shape.rank()) {
       return InvalidArgument(
           "Arguments to BatchDot have different ranks: %s vs. %s",
           ShapeUtil::HumanString(x_shape), ShapeUtil::HumanString(y_shape));
     }
-    const int ndims = ShapeUtil::Rank(x_shape);
+    const int ndims = x_shape.rank();
     if (ndims < 2) {
       return InvalidArgument(
           "Arguments to BatchDot must have rank >= 2: got %d", ndims);
@@ -169,7 +169,7 @@ XlaOp TransposeInMinorDims(XlaOp x) {
   XlaBuilder* builder = x.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
-    const int64 n_dims = ShapeUtil::Rank(shape);
+    const int64 n_dims = shape.rank();
     TF_RET_CHECK(n_dims >= 2);
     std::vector<int64> permutation(n_dims);
     std::iota(permutation.begin(), permutation.end(), 0);
@@ -154,7 +154,7 @@ struct QRBlockResult {
 StatusOr<QRBlockResult> QRBlock(XlaOp a, PrecisionConfig::Precision precision) {
   XlaBuilder* builder = a.builder();
   TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
-  const int num_dims = ShapeUtil::Rank(a_shape);
+  const int num_dims = a_shape.rank();
   if (num_dims < 2) {
     return InvalidArgument("Argument to QR must have rank >= 2; got shape %s",
                            a_shape.ToString());
@@ -325,7 +325,7 @@ StatusOr<QRDecompositionResult> QRDecomposition(
     PrecisionConfig::Precision precision) {
   XlaBuilder* builder = a.builder();
   TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
-  const int num_dims = ShapeUtil::Rank(a_shape);
+  const int num_dims = a_shape.rank();
   if (num_dims < 2) {
     return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s",
                            a_shape.ToString());
@@ -26,7 +26,7 @@ XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64> start,
 
     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
 
-    const int64 n_dims = ShapeUtil::Rank(shape);
+    const int64 n_dims = shape.rank();
     TF_RET_CHECK(n_minor_dims <= n_dims);
     auto major_dims = AsInt64Slice(shape.dimensions())
                           .subspan(
@@ -55,7 +55,7 @@ XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64> start) {
     std::vector<int32> start_as_int32(start.begin(), start.end());
     auto start_constant = ConstantR1<int32>(builder, start_as_int32);
     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
-    const int64 n_dims = ShapeUtil::Rank(shape);
+    const int64 n_dims = shape.rank();
     TF_ASSIGN_OR_RETURN(Shape start_constant_shape,
                         builder->GetShape(start_constant));
     const int64 start_length =
@@ -70,7 +70,7 @@ XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update,
   XlaBuilder* builder = x.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
-    const int64 n_dims = ShapeUtil::Rank(shape);
+    const int64 n_dims = shape.rank();
     const int64 n_minor_dims = start.size();
     TF_RET_CHECK(n_minor_dims <= n_dims);
     std::vector<int64> padded_start(n_dims, 0);
@@ -94,7 +94,7 @@ XlaOp PrependZerosInMajorDims(XlaOp x, absl::Span<const XlaOp> starts) {
   XlaBuilder* builder = x.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
-    const int64 n_dims = ShapeUtil::Rank(shape);
+    const int64 n_dims = shape.rank();
     auto zero = Reshape(ConstantR0<int32>(builder, 0), {1});
     std::vector<XlaOp> padded_starts(n_dims, zero);
     for (int i = 0; i < starts.size(); ++i) {
@@ -111,7 +111,7 @@ XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span<const XlaOp> starts,
   XlaBuilder* builder = x.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
-    const int64 n_dims = ShapeUtil::Rank(shape);
+    const int64 n_dims = shape.rank();
     int64 n_minor_dims = starts.size();
     TF_RET_CHECK(n_minor_dims == sizes.size());
     TF_RET_CHECK(n_minor_dims <= n_dims);
@@ -38,7 +38,7 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) {
   XlaBuilder* builder = a.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(a));
-    int ndims = ShapeUtil::Rank(shape);
+    int ndims = shape.rank();
     int64 n = ShapeUtil::GetDimension(shape, -1);
     int64 num_blocks = n / block_size;
 
@@ -262,7 +262,7 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
     int64 block_size = ShapeUtil::GetDimension(blocks_shape, -1);
 
     TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
-    int64 ndims = ShapeUtil::Rank(a_shape);
+    int64 ndims = a_shape.rank();
     int64 n = ShapeUtil::GetDimension(a_shape, -1);
     int64 num_blocks = n / block_size + (n % block_size != 0);
     int64 m_dim = (left_side) ? -1 : -2;
@@ -356,13 +356,13 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
     TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
-    if (ShapeUtil::Rank(a_shape) != ShapeUtil::Rank(b_shape)) {
+    if (a_shape.rank() != b_shape.rank()) {
       return InvalidArgument(
           "Arguments to TriangularSolve have shapes with different ranks: "
           "%s vs. %s",
           ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
     }
-    const int64 ndims = ShapeUtil::Rank(a_shape);
+    const int64 ndims = a_shape.rank();
     if (ndims < 2) {
       return InvalidArgument(
           "Arguments to TriangularSolve was rank %d but must have rank >= 2.",
@@ -50,7 +50,7 @@ OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles) {
   OpSharding result;
   result.set_type(OpSharding::Type::OpSharding_Type_OTHER);
 
-  CHECK_EQ(ShapeUtil::Rank(tile_shape), 1);
+  CHECK_EQ(tile_shape.rank(), 1);
   std::vector<int64> dimensions(1, num_tiles);
   *result.mutable_tile_shape() = tile_shape.ToProto();
   auto& tile_dimension =
@@ -343,7 +343,7 @@ StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
   TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
 
   CHECK(ShapeUtil::IsScalar(operand_shape) ||
-        ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape));
+        operand_shape.rank() == output_shape.rank());
   Shape broadcast_shape =
       ShapeUtil::ChangeElementType(output_shape, operand_shape.element_type());
 
@@ -355,7 +355,7 @@ StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
   // Do explicit broadcast for degenerate broadcast.
   std::vector<int64> broadcast_dimensions;
   std::vector<int64> reshaped_dimensions;
-  for (int i = 0; i < ShapeUtil::Rank(operand_shape); i++) {
+  for (int i = 0; i < operand_shape.rank(); i++) {
     if (operand_shape.dimensions(i) == output_shape.dimensions(i)) {
       broadcast_dimensions.push_back(i);
       reshaped_dimensions.push_back(operand_shape.dimensions(i));
@@ -398,8 +398,8 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
                         binop, lhs_shape, rhs_shape, broadcast_dimensions));
     *instr.mutable_shape() = shape.ToProto();
 
-    const int64 lhs_rank = ShapeUtil::Rank(lhs_shape);
-    const int64 rhs_rank = ShapeUtil::Rank(rhs_shape);
+    const int64 lhs_rank = lhs_shape.rank();
+    const int64 rhs_rank = rhs_shape.rank();
 
     XlaOp updated_lhs = lhs;
     XlaOp updated_rhs = rhs;
@@ -413,8 +413,7 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
       for (int64 size : shape.dimensions()) {
         to_size.push_back(size);
       }
-      for (int64 from_dim = 0; from_dim < ShapeUtil::Rank(from_shape);
-           from_dim++) {
+      for (int64 from_dim = 0; from_dim < from_shape.rank(); from_dim++) {
         int64 to_dim = broadcast_dimensions[from_dim];
         to_size[to_dim] = from_shape.dimensions(from_dim);
       }
@@ -563,10 +562,10 @@ XlaOp XlaBuilder::Broadcast(const XlaOp& operand,
     // output, so to append dimensions on the left the instruction's dimensions
     // should just be the n highest dimension numbers of the output shape where
     // n is the number of input dimensions.
-    const int64 operand_rank = ShapeUtil::Rank(operand_shape);
+    const int64 operand_rank = operand_shape.rank();
     std::vector<int64> dimensions(operand_rank);
     for (int i = 0; i < operand_rank; ++i) {
-      dimensions[i] = i + ShapeUtil::Rank(shape) - operand_rank;
+      dimensions[i] = i + shape.rank() - operand_rank;
     }
     return InDimBroadcast(shape, operand, dimensions);
   });
@@ -639,10 +638,10 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index,
                              int64 limit_index, int64 stride, int64 dimno) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
-    std::vector<int64> starts(ShapeUtil::Rank(shape), 0);
+    std::vector<int64> starts(shape.rank(), 0);
     std::vector<int64> limits(shape.dimensions().begin(),
                               shape.dimensions().end());
-    std::vector<int64> strides(ShapeUtil::Rank(shape), 1);
+    std::vector<int64> strides(shape.rank(), 1);
     starts[dimno] = start_index;
     limits[dimno] = limit_index;
     strides[dimno] = stride;
@@ -780,7 +779,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand,
     VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ",");
 
     std::vector<int64> new_sizes;
-    for (int i = 0; i < ShapeUtil::Rank(original_shape); ++i) {
+    for (int i = 0; i < original_shape.rank(); ++i) {
      if (i <= dimensions.front() || i > dimensions.back()) {
        new_sizes.push_back(original_shape.dimensions(i));
      } else {
@@ -915,13 +914,13 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
 Status XlaBuilder::VerifyConvolution(
     const Shape& lhs_shape, const Shape& rhs_shape,
     const ConvolutionDimensionNumbers& dimension_numbers) const {
-  if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) {
+  if (lhs_shape.rank() != rhs_shape.rank()) {
     return InvalidArgument(
         "Convolution arguments must have same number of "
         "dimensions. Got: %s and %s",
         ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
   }
-  int num_dims = ShapeUtil::Rank(lhs_shape);
+  int num_dims = lhs_shape.rank();
   if (num_dims < 2) {
     return InvalidArgument(
         "Convolution expects argument arrays with >= 3 dimensions. "
@@ -1582,7 +1581,7 @@ XlaOp XlaBuilder::Sort(const XlaOp& keys, absl::Span<const XlaOp> values,
     *instr.mutable_shape() = shape.ToProto();
     if (dimension == -1) {
       TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys));
-      dimension = ShapeUtil::Rank(keys_shape) - 1;
+      dimension = keys_shape.rank() - 1;
     }
     instr.add_dimensions(dimension);
     std::vector<XlaOp> operands{keys};
@@ -1652,12 +1651,12 @@ XlaOp XlaBuilder::Map(absl::Span<const XlaOp> operands,
     *instr.mutable_shape() = shape.ToProto();
 
     Shape output_shape(instr.shape());
-    const int64 output_rank = ShapeUtil::Rank(output_shape);
+    const int64 output_rank = output_shape.rank();
     AddCalledComputation(computation, &instr);
     std::vector<XlaOp> new_operands(operands.begin(), operands.end());
     for (XlaOp& new_operand : new_operands) {
       TF_ASSIGN_OR_RETURN(Shape shape, GetShape(new_operand));
-      const int64 rank = ShapeUtil::Rank(shape);
+      const int64 rank = shape.rank();
       if (rank != output_rank) {
         TF_ASSIGN_OR_RETURN(new_operand,
                             InDimBroadcast(output_shape, new_operand, {}));
@@ -1866,7 +1865,7 @@ XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value,
                             const XlaComputation& computation) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
-    std::vector<int64> all_dimnos(ShapeUtil::Rank(operand_shape));
+    std::vector<int64> all_dimnos(operand_shape.rank());
     std::iota(all_dimnos.begin(), all_dimnos.end(), 0);
     return Reduce(operand, init_value, computation, all_dimnos);
   });
@@ -141,7 +141,7 @@ namespace xla {
 
 /* static */ bool IndexUtil::IndexInBounds(const Shape& shape,
                                            absl::Span<const int64> index) {
-  int64 rank = ShapeUtil::Rank(shape);
+  int64 rank = shape.rank();
   if (rank != index.size()) {
     return false;
   }
@@ -211,19 +211,19 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
   }
 
   if (layout.format() == DENSE) {
-    if (layout.minor_to_major_size() != ShapeUtil::Rank(shape)) {
+    if (layout.minor_to_major_size() != shape.rank()) {
       return InvalidArgument(
           "layout minor_to_major field contains %d elements, "
          "but shape is rank %d: {%s}; shape: %s",
-          layout.minor_to_major_size(), ShapeUtil::Rank(shape),
+          layout.minor_to_major_size(), shape.rank(),
           absl::StrJoin(layout.minor_to_major(), ", "),
          shape.ShortDebugString());
     }
 
-    std::vector<bool> dimensions_in_layout(ShapeUtil::Rank(shape), false);
-    for (int64 i = 0; i < ShapeUtil::Rank(shape); ++i) {
+    std::vector<bool> dimensions_in_layout(shape.rank(), false);
+    for (int64 i = 0; i < shape.rank(); ++i) {
       int64 dim = layout.minor_to_major(i);
-      if (dim < 0 || dim >= ShapeUtil::Rank(shape)) {
+      if (dim < 0 || dim >= shape.rank()) {
         return InvalidArgument(
             "layout minor_to_major field has out-of-bounds value: %s",
             HumanString(layout));
@@ -376,7 +376,7 @@ Status CopyLayoutInternal(const Shape& src, Shape* dst) {
     }
   } else {
     if (src.has_layout()) {
-      if (ShapeUtil::Rank(src) != ShapeUtil::Rank(*dst)) {
+      if (src.rank() != dst->rank()) {
         return InvalidArgument("cannot copy layout from shape: ranks differs");
       }
       TF_RETURN_IF_ERROR(
@@ -410,7 +410,7 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
     }
     return true;
   } else if (ShapeUtil::IsArray(lhs)) {
-    return ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) &&
+    return lhs.rank() == rhs.rank() &&
           LayoutUtil::Equal(lhs.layout(), rhs.layout());
   } else {
     // Layouts of non-array and non-tuple shapes is ignored.
@@ -129,7 +129,7 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
           new char[max_sparse_elements *
                    ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]);
       piece->set_sparse_indices(
-          new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape)));
+          new SparseIndexArray(max_sparse_elements, shape.rank()));
     } else {
       piece->set_buffer(new char[piece->size_bytes()]);
     }
@@ -208,16 +208,15 @@ template <typename NativeT>
 Status MutableLiteralBase::CopySliceFromInternal(
     const LiteralBase& src_literal, absl::Span<const int64> src_base,
     absl::Span<const int64> dest_base, absl::Span<const int64> copy_size) {
-  TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size());
-  TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size());
+  TF_RET_CHECK(src_literal.shape().rank() == src_base.size());
+  TF_RET_CHECK(shape().rank() == dest_base.size());
 
   auto linear_index = [](const Shape& shape,
                          absl::Span<const int64> multi_index) {
     return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
   };
 
-  if (ShapeUtil::Rank(src_literal.shape()) == 0 ||
-      ShapeUtil::Rank(shape()) == 0) {
+  if (src_literal.shape().rank() == 0 || shape().rank() == 0) {
     // If any of the two shapes are scalars, we can just call the StridedCopy()
     // directly, and we know we will be copying only one value.
     TF_RET_CHECK(copy_size.empty());
@@ -375,7 +374,7 @@ void CopyElementsBetween(absl::Span<NativeT> dest,
   if (ShapeUtil::IsZeroElementArray(dest_shape)) {
     return;
   }
-  std::vector<int64> index(ShapeUtil::Rank(dest_shape));
+  std::vector<int64> index(dest_shape.rank());
   do {
     dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
         src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
@@ -392,7 +391,7 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
     memcpy(buffer(), src.buffer(), src.size_bytes());
   } else {
     TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape()));
-    std::vector<int64> origin(ShapeUtil::Rank(subshape()), 0);
+    std::vector<int64> origin(subshape().rank(), 0);
     switch (subshape().element_type()) {
 #define COPY_ELEMENTS(XLA_T, NATIVE_T) \
   case (XLA_T):                        \
@@ -563,7 +562,7 @@ Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal,
 
 void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
   CHECK(ShapeUtil::IsArray(shape()));
-  CHECK_EQ(ShapeUtil::Rank(shape()), 1);
+  CHECK_EQ(shape().rank(), 1);
   CHECK_EQ(element_count(), values.bits());
   CHECK_EQ(shape().element_type(), PRED);
   for (int64 i = 0; i < static_cast<int64>(values.bits()); ++i) {
@@ -648,8 +647,7 @@ StatusOr<Literal> LiteralBase::Reshape(
   }
   Literal output;
   if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
-    output =
-        Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape())));
+    output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank()));
   } else {
     output = Clone();
   }
@@ -672,7 +670,7 @@ StatusOr<Literal> LiteralBase::Reshape(
 
 Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
   CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
-  CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
+  CHECK(IsPermutation(permutation, shape().rank()))
       << "Given permutation is not a permutation of dimension numbers";
   // To transpose the array, we just permute the dimensions and layout, and
   // do a straight memory copy of the raw data set.
@@ -711,10 +709,10 @@ template <typename NativeT>
 Literal LiteralBase::SliceInternal(
     const Shape& result_shape, absl::Span<const int64> start_indices) const {
   Literal result_literal(result_shape);
-  DimensionVector new_indices(ShapeUtil::Rank(result_shape));
+  DimensionVector new_indices(result_shape.rank());
   result_literal.EachCell<NativeT>(
       [&](absl::Span<const int64> indices, NativeT /*value*/) {
-        for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
+        for (int64 i = 0; i < result_shape.rank(); ++i) {
          new_indices[i] = indices[i] + start_indices[i];
        }
        NativeT value = Get<NativeT>(new_indices);
@@ -728,7 +726,7 @@ Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
   CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
 
   DimensionVector result_dimensions;
-  for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) {
+  for (int64 dnum = 0; dnum < shape().rank(); ++dnum) {
     CHECK_GE(start_indices[dnum], 0);
     CHECK_LE(limit_indices[dnum], shape().dimensions(dnum))
         << "dnum = " << dnum;
@@ -1056,7 +1054,7 @@ void SparseArrayToStringHelper(const LiteralBase& literal,
     pieces->push_back(ShapeToString(print_layout, subshape));
   }
   pieces->push_back("{");
-  int64 rank = ShapeUtil::Rank(subshape);
+  int64 rank = subshape.rank();
   int64 num_elements = literal.sparse_element_count();
   for (int64 i = 0; i < num_elements; ++i) {
     if (i > 0) {
@@ -1079,7 +1077,7 @@ void DenseArrayToStringHelper(const LiteralBase& literal,
                               const ShapeIndex& shape_index, bool print_shape,
                               bool print_layout, std::vector<string>* pieces) {
   const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
-  int64 rank = ShapeUtil::Rank(subshape);
+  int64 rank = subshape.rank();
 
   std::function<void(absl::Span<const int64> dimensions, std::vector<int64>*)>
       to_string_recursive = [&](absl::Span<const int64> dimensions,
@@ -1433,7 +1431,7 @@ StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const {
 template <typename NativeT>
 bool LiteralBase::Piece::EqualElementsInternal(
     const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
-  if (multi_index->size() == ShapeUtil::Rank(subshape())) {
+  if (multi_index->size() == subshape().rank()) {
     return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
   }
   for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) {
@@ -1722,7 +1720,7 @@ bool LiteralBase::IsR1Iota() const {
     return false;
   }
 
-  if (ShapeUtil::Rank(shape()) != 1) {
+  if (shape().rank() != 1) {
     return false;
   }
 
@@ -1932,14 +1930,12 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
   if (LayoutUtil::IsSparseArray(subshape())) {
     // Compute the number of elements (indices) in the sparse shape and reserve
     // the necessary space in spare_indices.
-    TF_RET_CHECK(ShapeUtil::Rank(subshape()) != 0)
-        << "Scalar shapes cannot be sparse";
-    TF_RET_CHECK(proto.sparse_indices_size() % ShapeUtil::Rank(subshape()) == 0)
+    TF_RET_CHECK(subshape().rank() != 0) << "Scalar shapes cannot be sparse";
+    TF_RET_CHECK(proto.sparse_indices_size() % subshape().rank() == 0)
         << "Unexpected number of indices in proto ("
         << proto.sparse_indices_size() << ") for shape of rank "
-        << ShapeUtil::Rank(subshape());
-    const int64 index_count =
-        proto.sparse_indices_size() / ShapeUtil::Rank(subshape());
+        << subshape().rank();
+    const int64 index_count = proto.sparse_indices_size() / subshape().rank();
     sparse_indices()->Resize(index_count);
 
     // Copy the indices from the proto into the SparseIndexArray object.
@@ -2065,7 +2061,7 @@ int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
 
 string LiteralBase::GetR1U8AsString() const {
   CHECK(ShapeUtil::IsArray(shape()));
-  CHECK_EQ(ShapeUtil::Rank(shape()), 1);
+  CHECK_EQ(shape().rank(), 1);
   CHECK_EQ(shape().element_type(), U8);
   return string(absl::bit_cast<const char*>(data<uint8>().data()),
                 ShapeUtil::ElementsIn(shape()));
@@ -961,7 +961,7 @@ void MutableLiteralBase::AppendSparseElement(
   Piece& p = piece(shape_index);
   const Shape& subshape = p.subshape();
   CHECK(LayoutUtil::IsSparseArray(subshape));
-  int64 rank = ShapeUtil::Rank(subshape);
+  int64 rank = subshape.rank();
   CHECK_EQ(multi_index.size(), rank);
   int64 last_element = p.sparse_indices()->index_count();
   CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout()));
@@ -977,7 +977,7 @@ void LiteralBase::EachCell(
   if (ShapeUtil::IsZeroElementArray(shape())) {
     return;
   }
-  std::vector<int64> indices(ShapeUtil::Rank(shape()), 0);
+  std::vector<int64> indices(shape().rank(), 0);
   do {
     per_cell(indices, Get<NativeT>(indices));
   } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
@@ -986,7 +986,7 @@ void LiteralBase::EachCell(
 template <typename NativeT>
 inline void MutableLiteralBase::PopulateR1(absl::Span<const NativeT> values) {
   CHECK(ShapeUtil::IsArray(shape()));
-  CHECK_EQ(ShapeUtil::Rank(shape()), 1);
+  CHECK_EQ(shape().rank(), 1);
   CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
   CHECK_EQ(shape().element_type(),
            primitive_util::NativeToPrimitiveType<NativeT>());
@@ -998,7 +998,7 @@ template <typename NativeT>
 void MutableLiteralBase::PopulateR2(
     std::initializer_list<std::initializer_list<NativeT>> values) {
   CHECK(ShapeUtil::IsArray(shape()));
-  CHECK_EQ(ShapeUtil::Rank(shape()), 2);
+  CHECK_EQ(shape().rank(), 2);
   CHECK_EQ(shape().element_type(),
            primitive_util::NativeToPrimitiveType<NativeT>());
 
@@ -1024,7 +1024,7 @@ void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) {
   CHECK(ShapeUtil::IsArray(shape()));
   CHECK_EQ(shape().element_type(),
            primitive_util::NativeToPrimitiveType<NativeT>());
-  CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions());
+  CHECK_EQ(shape().rank(), values.num_dimensions());
   for (int dim = 0; dim < values.num_dimensions(); ++dim) {
     CHECK_EQ(values.dim(dim), shape().dimensions(dim));
   }
@@ -1053,7 +1053,7 @@ void MutableLiteralBase::PopulateSparse(SparseIndexArray indices,
                                         absl::Span<const NativeT> values,
                                         bool sort) {
   CHECK(LayoutUtil::IsSparseArray(shape()));
-  int rank = ShapeUtil::Rank(shape());
+  int rank = shape().rank();
   CHECK_EQ(indices.rank(), rank);
   int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout());
   CHECK_LE(indices.max_indices(), max_elements);
@@ -1077,7 +1077,7 @@ template <typename NativeT, typename FnType>
 Status MutableLiteralBase::PopulateInternal(const FnType& generator,
                                             bool parallel) {
   const Shape& this_shape = shape();
-  const int64 rank = ShapeUtil::Rank(this_shape);
+  const int64 rank = this_shape.rank();
   TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
   TF_RET_CHECK(this_shape.element_type() ==
                primitive_util::NativeToPrimitiveType<NativeT>());
@@ -463,7 +463,7 @@ class NearComparator {
      }
      return;
    }
-    std::vector<int64> multi_index(ShapeUtil::Rank(actual_.shape()), 0);
+    std::vector<int64> multi_index(actual_.shape().rank(), 0);
    CompareLiteralsSlow(0, &multi_index);
  }
 
@@ -777,7 +777,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) {
      }
    }
  } else if (ShapeUtil::IsArray(expected)) {
-    if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) {
+    if (expected.rank() != actual.rank()) {
      return InvalidArgument("want rank of %s got rank of %s",
                             ShapeUtil::HumanString(expected),
                             ShapeUtil::HumanString(actual));
@@ -132,7 +132,7 @@ PyObject* PyShapeInfoFromXlaShape(const Shape& shape) {
          PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i)));
    }
  } else {
-    int rank = ShapeUtil::Rank(shape);
+    int rank = shape.rank();
    dimensions = PyTuple_New(rank);
    for (int i = 0; i < rank; ++i) {
      PyTuple_SET_ITEM(dimensions, i,
@@ -354,7 +354,7 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) {
    }
    return tuple;
  } else {
-    int rank = ShapeUtil::Rank(literal.shape());
+    int rank = literal.shape().rank();
    std::vector<long> dimensions(rank);  // NOLINT - PyArray requires a long*
    for (int i = 0; i < rank; i++) {
      dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i);
@@ -552,7 +552,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
  Literal result_literal =
      evaluator.Evaluate<const Literal*>(*computation, {}).ConsumeValueOrDie();
 
-  CHECK_EQ(ShapeUtil::Rank(result_literal.shape()), 4);
+  CHECK_EQ(result_literal.shape().rank(), 4);
  auto result =
      absl::make_unique<Array4D<float>>(result_literal.shape().dimensions(0),
                                        result_literal.shape().dimensions(1),
@@ -251,7 +251,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
 
  // Reshapes an instruction to rank 1 if it is not already rank 1.
  HloInstruction* Flatten(HloInstruction* hlo) {
-    if (ShapeUtil::Rank(hlo->shape()) == 1) {
+    if (hlo->shape().rank() == 1) {
      return hlo;
    }
    return computation_->AddInstruction(HloInstruction::CreateReshape(
@@ -687,7 +687,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate(
    return Status::OK();
  }
  PaddingConfig padding_config;
-  for (int64 dim = 0; dim < ShapeUtil::Rank(operands[0]->shape()); ++dim) {
+  for (int64 dim = 0; dim < operands[0]->shape().rank(); ++dim) {
    auto padding_config_dim = padding_config.add_dimensions();
    padding_config_dim->set_edge_padding_high(0);
    padding_config_dim->set_edge_padding_low(0);
@@ -754,7 +754,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
  }
 
  // If a literal is an increasing sequence from zero, replace it with an iota.
-  if (ShapeUtil::Rank(constant->shape()) == 1 &&
+  if (constant->shape().rank() == 1 &&
      ShapeUtil::ElementsIn(constant->shape()) > 1 &&
      constant->literal().IsR1Iota()) {
    return ReplaceWithNewInstruction(
@@ -930,9 +930,9 @@ StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
    return -1;
  };
 
-  const int64 dot_rank = ShapeUtil::Rank(dot->shape());
-  const int64 rhs_rank = ShapeUtil::Rank(rhs->shape());
-  const int64 lhs_rank = ShapeUtil::Rank(lhs->shape());
+  const int64 dot_rank = dot->shape().rank();
+  const int64 rhs_rank = rhs->shape().rank();
+  const int64 lhs_rank = lhs->shape().rank();
  const auto& dnums = dot->dot_dimension_numbers();
  if (dnums.rhs_contracting_dimensions_size() > 1) {
    return false;
@@ -1036,7 +1036,7 @@ StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
  // )
  if (lhs_rank == 1 ||
      (lhs_rank == 2 && lhs->shape().dimensions(lhs_kept_dim) == 1)) {
-    if (ShapeUtil::Rank(rhs->shape()) == 1) {
+    if (rhs->shape().rank() == 1) {
      TF_RETURN_IF_ERROR(
          ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32(
                                      multiply(Flatten(lhs), rhs), 0))));
@@ -1449,8 +1449,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
      dot->shape().element_type() != BF16) {
    return Status::OK();
  }
-  if (ShapeUtil::Rank(lhs->shape()) > 2 || ShapeUtil::Rank(rhs->shape()) > 2 ||
-      ShapeUtil::Rank(dot->shape()) > 2) {
+  if (lhs->shape().rank() > 2 || rhs->shape().rank() > 2 ||
+      dot->shape().rank() > 2) {
    if (options_.enable_dot_strength_reduction() &&
        !options_.is_layout_sensitive()) {
      TF_RETURN_IF_ERROR(HandleDotStrengthReduction(dot).status());
@@ -1732,8 +1732,7 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
 
  // A degenerate broadcast that has the same input and output rank can be
  // converted into a transpose.
-  if (ShapeUtil::Rank(broadcast->shape()) ==
-          ShapeUtil::Rank(operand->shape()) &&
+  if (broadcast->shape().rank() == operand->shape().rank() &&
      ShapeUtil::ElementsIn(broadcast->shape()) ==
          ShapeUtil::ElementsIn(operand->shape())) {
    VLOG(10) << "transform broadcast(X) -> transpose(X) where "
@@ -1888,7 +1887,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
  if (HasInteriorPadding(pad->padding_config())) {
    PaddingConfig padding_config = pad->padding_config();
    bool cleared_interior_padding = false;
-    for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) {
+    for (int64 i = 0; i < pad->shape().rank(); ++i) {
      if (padding_config.dimensions(i).interior_padding() > 0 &&
          pad->operand(0)->shape().dimensions(i) == 1) {
        cleared_interior_padding = true;
@@ -2276,7 +2275,7 @@ StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice(
  if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) {
    VLOG(10) << "Trying to simplify scalar slice of concat";
    // Only do this for R1, there's no chance of this being useful otherwise.
-    if (ShapeUtil::Rank(slice->shape()) != 1) {
+    if (slice->shape().rank() != 1) {
      VLOG(10) << "Not folding, slice is not rank 1";
      return false;
    }
@@ -2326,7 +2325,7 @@ StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape(
    return false;
  }
  HloInstruction* new_slice_operand = reshape->mutable_operand(0);
-  int64 slice_rank = ShapeUtil::Rank(slice->shape());
+  int64 slice_rank = slice->shape().rank();
  std::vector<int64> sliced_dims;
  for (int64 i = 0; i < slice_rank; ++i) {
    if (slice->slice_starts(i) != 0 ||
@@ -2338,7 +2337,7 @@ StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape(
  if (sliced_dims.size() == 1 && sliced_dims[0] == 0 &&
      slice->slice_starts(0) == 0) {
    const Shape& new_slice_shape = new_slice_operand->shape();
-    const int64 rank = ShapeUtil::Rank(new_slice_shape);
+    const int64 rank = new_slice_shape.rank();
    std::vector<int64> new_slice_starts(rank, 0);
    std::vector<int64> new_slice_stides(rank, 1);
    std::vector<int64> new_slice_limits(new_slice_shape.dimensions().begin(),
@@ -2456,8 +2455,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
  // A Transpose feeding a reduce can simply permute the reduction dimensions
  // field if the output of the reduce is a vector or scalar. Higher ranked
  // result may require a transpose of the output.
-  if (ShapeUtil::Rank(reduce->shape()) <= 1 &&
-      arg->opcode() == HloOpcode::kTranspose) {
+  if (reduce->shape().rank() <= 1 && arg->opcode() == HloOpcode::kTranspose) {
    auto transpose_dimensions = arg->dimensions();
    std::vector<int64> new_reduce_dimensions;
    for (auto dim : dimensions) {
@@ -2516,8 +2514,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
    std::vector<std::pair<int64, int64>> unmodified_dims =
        ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(),
                                                 arg->shape());
-    std::vector<bool> arg_dim_in_output(ShapeUtil::Rank(arg->shape()), true);
-    std::vector<bool> arg_dim_unmodified(ShapeUtil::Rank(arg->shape()), false);
+    std::vector<bool> arg_dim_in_output(arg->shape().rank(), true);
+    std::vector<bool> arg_dim_unmodified(arg->shape().rank(), false);
    for (auto dim : dimensions) {
      arg_dim_in_output[dim] = false;
    }
@@ -2542,7 +2540,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
      }
    }
    std::vector<int64> new_reduce_dimensions;
-    for (int64 i = 0; i < ShapeUtil::Rank(arg->operand(0)->shape()); ++i) {
+    for (int64 i = 0; i < arg->operand(0)->shape().rank(); ++i) {
      if (dimensions_not_to_reduce.count(i) == 0) {
        new_reduce_dimensions.push_back(i);
      }
@@ -2779,7 +2777,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
  // Carry out the folding of the pad into reduce_window.
  VLOG(10) << "Folding pad into reduce-window.";
  Window new_window = window;
-  const int64 rank = ShapeUtil::Rank(reduce_window->shape());
+  const int64 rank = reduce_window->shape().rank();
  TF_RET_CHECK(pad_config.dimensions_size() == rank);
  TF_RET_CHECK(window.dimensions_size() == rank);
  for (int64 i = 0; i < rank; ++i) {
@@ -2862,7 +2860,7 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) {
  // - Use this as the indices parameter of scatter, and set updates
  //   of the scatter to be a reshaped 'values' parameter of sort (adding
  //   'rank' many 1 dimensions at the end).
-  int64 rank = ShapeUtil::Rank(operand->shape());
+  int64 rank = operand->shape().rank();
  Shape extended_shape = operand->shape();
  extended_shape.add_dimensions(1);
  extended_shape.mutable_layout()->add_minor_to_major(rank);
@@ -3498,7 +3498,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
 
  // Create the reduce-window.
  Window window;
-  for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) {
+  for (int64 i = 0; i < pad->shape().rank(); ++i) {
    auto* dim = window.add_dimensions();
    dim->set_size(1);
    dim->set_padding_low(10);
@@ -3584,7 +3584,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
 
  // Create the reduce-window.
  Window window;
-  for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) {
+  for (int64 i = 0; i < pad->shape().rank(); ++i) {
    auto* dim = window.add_dimensions();
    dim->set_size(1);
    dim->set_padding_low(10);
@@ -123,7 +123,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
    auto elements_per_feature_u32 = add_instruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(1)));
 
-    for (int64 i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) {
+    for (int64 i = 0; i < operand->shape().rank(); ++i) {
      if (i == feature_index) {
        continue;
      }
@@ -229,7 +229,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
      add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {}));
  std::vector<int64> dimensions_without_feature;
 
-  for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
+  for (int64 i = 0; i < operand_shape.rank(); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
@@ -357,7 +357,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
 
  std::vector<int64> dimensions_without_feature;
 
-  for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
+  for (int64 i = 0; i < operand_shape.rank(); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
@@ -494,7 +494,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
 
  std::vector<int64> dimensions_without_feature;
 
-  for (int64 i = 0; i < ShapeUtil::Rank(activation_shape); ++i) {
+  for (int64 i = 0; i < activation_shape.rank(); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
@@ -1550,7 +1550,7 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
 }
 
 // Return whether the given shape is rank 2.
-static bool IsRank2(const Shape& shape) { return ShapeUtil::Rank(shape) == 2; }
+static bool IsRank2(const Shape& shape) { return shape.rank() == 2; }
 
 // In a gemm operation where output = lhs * rhs, check whether the given shapes
 // are valid for the operation.
@@ -535,7 +535,7 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) {
    higher_dimensions *= normalized_keys_shape.dimensions(i);
  }
  int64 lower_dimensions = 1;
-  for (int64 i = ShapeUtil::Rank(normalized_keys_shape) - 1;
+  for (int64 i = normalized_keys_shape.rank() - 1;
       i > physical_dimension_to_sort; --i) {
    lower_dimensions *= normalized_keys_shape.dimensions(i);
  }
@@ -779,8 +779,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
  const auto init_value = select_and_scatter->operand(2);
  const Window& window = select_and_scatter->window();
  PrimitiveType operand_element_type = operand->shape().element_type();
-  const int64 rank = ShapeUtil::Rank(operand->shape());
-  CHECK_EQ(rank, ShapeUtil::Rank(source->shape()));
+  const int64 rank = operand->shape().rank();
+  CHECK_EQ(rank, source->shape().rank());
  CHECK_EQ(rank, window.dimensions_size());
 
  // TODO(b/31410564): Implement dilation for select-and-scatter.
@@ -173,7 +173,7 @@ Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) {
 
        // Find out the new dynamic dimension after reduce.
        int64 dimensions_not_reduced_count = 0;
-        for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) {
+        for (int i = 0; i < operand->shape().rank(); ++i) {
          if (dimension == i) {
            parent_->SetDynamicSize(reduce, {}, dimensions_not_reduced_count,
                                    dynamic_size);
@@ -207,7 +207,7 @@ Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
          result_dim_mapping[i] = current_result_dims++;
        }
 
-        for (int64 i = 0; i < ShapeUtil::Rank(dot->operand(0)->shape()); i++) {
+        for (int64 i = 0; i < dot->operand(0)->shape().rank(); i++) {
          if (!absl::c_linear_search(
                  dimension_numbers.lhs_contracting_dimensions(), i)) {
            if (operand_index == 0) {
@@ -217,7 +217,7 @@ Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
          }
        }
 
-        for (int64 i = 0; i < ShapeUtil::Rank(dot->operand(1)->shape()); i++) {
+        for (int64 i = 0; i < dot->operand(1)->shape().rank(); i++) {
          if (!absl::c_linear_search(
                  dimension_numbers.rhs_contracting_dimensions(), i) &&
              !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(),
@@ -121,10 +121,11 @@ Status DynamicParameterBinding::Verify(const HloModule& module) const {
                     dynamic_dimension.parameter_index));
        TF_RET_CHECK(
            dynamic_dimension.dimension <
-            ShapeUtil::Rank(ShapeUtil::GetSubshape(
+            ShapeUtil::GetSubshape(
                entry->parameter_instruction(dynamic_dimension.parameter_num)
                    ->shape(),
-                dynamic_dimension.parameter_index)));
+                dynamic_dimension.parameter_index)
+                .rank());
        return Status::OK();
      });
 }
@@ -1327,9 +1327,9 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
 
  // If implicit broadcast is needed, the source dimensions that are broadcast
  // have index 0.
-  CHECK_EQ(ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(hlo.shape()));
+  CHECK_EQ(operand_shape.rank(), hlo.shape().rank());
  llvm_ir::IrArray::Index source_index(target_index.GetType());
-  for (int64 i = 0; i < ShapeUtil::Rank(hlo.shape()); ++i) {
+  for (int64 i = 0; i < hlo.shape().rank(); ++i) {
    if (hlo.shape().dimensions(i) == operand_shape.dimensions(i)) {
      source_index.push_back(target_index[i]);
    } else {
@@ -1750,7 +1750,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
    const llvm_ir::IrArray::Index& index) {
  // Emit IR to read dynamic start indices from hlo->operand(1).
  const HloInstruction* input_hlo = hlo->operand(0);
-  const int64 rank = ShapeUtil::Rank(input_hlo->shape());
+  const int64 rank = input_hlo->shape().rank();
  // Use the same index type for all tensor accesses in the same kernel.
  llvm::Type* index_type = index.GetType();
  llvm_ir::IrArray::Index slice_start_index(index_type, rank);
@@ -1893,7 +1893,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
  const HloInstruction* update_hlo = hlo->operand(1);
  const HloInstruction* start_hlo = hlo->operand(2);
  // Calculate slice start/end indices.
-  const int64 rank = ShapeUtil::Rank(input_hlo->shape());
+  const int64 rank = input_hlo->shape().rank();
  llvm_ir::IrArray::Index slice_start_index(index.GetType(), rank);
  llvm_ir::IrArray::Index slice_limit_index(index.GetType(), rank);
  // Slice intersection gathers (ANDs) conditions on all ranks for which
@@ -2225,7 +2225,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
      auto* iota = Cast<HloIotaInstruction>(hlo);
      PrimitiveType element_type = iota->shape().element_type();
      IrArray::Index elem_index =
-          ShapeUtil::Rank(iota->shape()) > 1
+          iota->shape().rank() > 1
              ? target_index.SourceIndexOfBroadcast(
                    iota->shape(),
                    ShapeUtil::MakeShapeWithDescendingLayout(
@@ -50,10 +50,10 @@ static HloInstruction* PadInstruction(HloInstruction* instr,
  auto* zero = comp->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
 
-  PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape));
+  PaddingConfig pad_config = MakeNoPaddingConfig(shape.rank());
 
  bool added_padding = false;
-  for (int64 dim = 0; dim < ShapeUtil::Rank(shape); ++dim) {
+  for (int64 dim = 0; dim < shape.rank(); ++dim) {
    if (shape.dimensions(dim) == new_shape.dimensions(dim)) {
      continue;
    }
@@ -219,7 +219,7 @@ bool CudnnConvPaddingLegalization::CanonicalizeBackwardFilterConvolution(
  Window new_backward_conv_window = backward_conv->window();
  // input_padding_config is the config of the kPad to be inserted.
  PaddingConfig input_padding_config =
-      MakeNoPaddingConfig(ShapeUtil::Rank(input->shape()));
+      MakeNoPaddingConfig(input->shape().rank());
  ConvolutionDimensionNumbers backward_conv_dnums =
      backward_conv->convolution_dimension_numbers();
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
@@ -315,8 +315,7 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
  DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction());
  CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
           dim_nums.rhs_batch_dimensions_size());
-  CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2,
-           ShapeUtil::Rank(output_shape_));
+  CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, output_shape_.rank());
 
  int64 row_dim = dim_nums.lhs_batch_dimensions_size();
  int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1;
@@ -43,14 +43,14 @@ bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
  const Layout* max_rank_layout;
  for (HloInstruction* param : params) {
    if (ShapeUtil::IsArray(param->shape()) &&
-        ShapeUtil::Rank(param->shape()) > max_rank) {
-      max_rank = ShapeUtil::Rank(param->shape());
+        param->shape().rank() > max_rank) {
+      max_rank = param->shape().rank();
      max_rank_layout = &param->shape().layout();
    }
  }
  return absl::c_all_of(params, [&](HloInstruction* param) {
    return (!ShapeUtil::IsArray(param->shape())) ||
-           (ShapeUtil::Rank(param->shape()) < max_rank) ||
+           (param->shape().rank() < max_rank) ||
           (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout));
  });
 }
@ -196,9 +196,9 @@ Status GpuLayoutAssignment::AddBackendConstraints(
CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
dim_nums.rhs_batch_dimensions_size());
CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2,
ShapeUtil::Rank(instruction->shape()));
instruction->shape().rank());
for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
CHECK_LT(batch_dim, ShapeUtil::Rank(instruction->shape()) - 2);
CHECK_LT(batch_dim, instruction->shape().rank() - 2);
}

// Set both inputs and the output to default layout.
@ -215,11 +215,11 @@ Status GpuLayoutAssignment::AddBackendConstraints(
TF_RETURN_IF_ERROR(
constraints->SetInstructionLayout(output_shape, instruction));
} else if (instruction->opcode() == HloOpcode::kSort &&
ShapeUtil::Rank(instruction->operand(0)->shape()) > 1) {
instruction->operand(0)->shape().rank() > 1) {
// Make sure that all the operands and the output(s) have the same layout.
Shape keys_shape = instruction->operand(0)->shape();
Layout keys_layout =
LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(keys_shape));
LayoutUtil::GetDefaultLayoutForRank(keys_shape.rank());
for (int64 i = 0; i < instruction->operand_count(); ++i) {
Shape shape = instruction->operand(i)->shape();
*shape.mutable_layout() = keys_layout;

@ -40,7 +40,7 @@ namespace {

// Return whether the given shape is rank 2 excluding the batch dimensions.
bool IsRank2(const Shape& shape, int64 batch_dimensions_size) {
return ShapeUtil::Rank(shape) == batch_dimensions_size + 2;
return shape.rank() == batch_dimensions_size + 2;
}

// In a gemm operation where output = lhs * rhs, check whether the given shapes

@ -698,8 +698,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
const auto* source = select_and_scatter->operand(1);
const Window& window = select_and_scatter->window();
PrimitiveType operand_element_type = operand->shape().element_type();
const int64 rank = ShapeUtil::Rank(operand->shape());
CHECK_EQ(rank, ShapeUtil::Rank(source->shape()));
const int64 rank = operand->shape().rank();
CHECK_EQ(rank, source->shape().rank());
CHECK_EQ(rank, window.dimensions_size());

TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
@ -1015,7 +1015,7 @@ Status IrEmitterUnnested::EmitScatter(
int64 raw_window_multidim_idx = 0;
std::vector<llvm::Value*> input_window_multidim;
std::vector<int64> input_window_bounds;
for (int64 i = 0, e = ShapeUtil::Rank(operand->shape()); i != e; ++i) {
for (int64 i = 0, e = operand->shape().rank(); i != e; ++i) {
if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) {
input_window_bounds.push_back(1); // Trivial dimension.
input_window_multidim.push_back(index.GetConstantWithIndexType(0));
@ -1027,12 +1027,11 @@ Status IrEmitterUnnested::EmitScatter(
++raw_window_multidim_idx;
}
}
DCHECK_EQ(input_window_multidim.size(), ShapeUtil::Rank(operand->shape()));
DCHECK_EQ(input_window_multidim.size(), operand->shape().rank());

// Insert a 1 dimension at the end if index_vector_dim requests one.
Shape scatter_indices_shape = scatter_indices->shape();
if (dim_numbers.index_vector_dim() ==
ShapeUtil::Rank(scatter_indices_shape)) {
if (dim_numbers.index_vector_dim() == scatter_indices_shape.rank()) {
scatter_indices_shape.add_dimensions(1);
scatter_indices_shape.mutable_layout()->add_minor_to_major(
dim_numbers.index_vector_dim());
@ -3191,7 +3190,7 @@ Status AreFusedReductionOutputsConsistent(
// dimensions from minor to major.
DimensionVector GetDimensionsToKeepMinorToMajor(
const Shape& input_shape, absl::Span<const int64> dims_to_reduce) {
DimensionVector input_dims(ShapeUtil::Rank(input_shape), 0);
DimensionVector input_dims(input_shape.rank(), 0);
absl::c_iota(input_dims, 0);
DimensionVector input_dims_to_keep;
for (int input_dim : input_dims) {
@ -3231,7 +3230,7 @@ std::tuple<int64, int64, int64> GetReductionToVectorDimensions(
if (input_dims_to_keep_minor_to_major.empty()) {
return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor);
}
DimensionVector input_dims(ShapeUtil::Rank(input_shape), 0);
DimensionVector input_dims(input_shape.rank(), 0);
absl::c_iota(input_dims, 0);
absl::Span<const int64> minor_to_major =
LayoutUtil::MinorToMajor(input_shape);

@ -189,8 +189,7 @@ StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
for (const HloInstruction* operand : operands) {
CHECK_EQ(computation, operand->parent());
operand_shapes.push_back(&operand->shape());
max_operand_rank =
std::max(max_operand_rank, ShapeUtil::Rank(operand->shape()));
max_operand_rank = std::max(max_operand_rank, operand->shape().rank());
}
std::vector<int64> map_dims(max_operand_rank);
std::iota(map_dims.begin(), map_dims.end(), 0);
@ -207,7 +206,7 @@ StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
HloOpcode binary_opcode,
HloModule* module) {
DCHECK_NE(nullptr, module);
std::vector<int64> all_dims(ShapeUtil::Rank(operand->shape()));
std::vector<int64> all_dims(operand->shape().rank());
std::iota(all_dims.begin(), all_dims.end(), 0);

auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {});

@ -443,7 +443,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
// concatenate dimensions of the operands taking part of the operation.
const Shape& reference_shape = operands[0]->shape();
CHECK(ShapeUtil::IsArray(reference_shape));
const int64 rank = ShapeUtil::Rank(reference_shape);
const int64 rank = reference_shape.rank();
const int64 concat_dim = concatenate->dimensions()[0];
CHECK_GE(concat_dim, 0);
CHECK_LT(concat_dim, rank);
@ -1036,11 +1036,9 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0));

TF_RET_CHECK(broadcast->dimensions().size() ==
ShapeUtil::Rank(operand.shape()))
TF_RET_CHECK(broadcast->dimensions().size() == operand.shape().rank())
<< "broadcast dimensions is of size: " << broadcast->dimensions().size()
<< " and rank of operand_to_broadcast is: "
<< ShapeUtil::Rank(operand.shape());
<< " and rank of operand_to_broadcast is: " << operand.shape().rank();
// Checks that operand's dimensions are the same as the broadcast's
// dimensions along the dimensions to be broadcasted.
for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
@ -1251,7 +1249,7 @@ template <typename KeyType, typename ValueType>
StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort,
const Literal& keys_literal,
const Literal& values_literal) {
auto rank = ShapeUtil::Rank(keys_literal.shape());
auto rank = keys_literal.shape().rank();
TF_RET_CHECK(
ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape()))
<< "Sort keys and values must have the same dimensions";

@ -1005,8 +1005,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
CHECK_GE(num_spatial_dims, 0);
CHECK_EQ(window.dimensions_size(), num_spatial_dims);

const auto lhs_rank = ShapeUtil::Rank(lhs_shape);
const auto rhs_rank = ShapeUtil::Rank(rhs_shape);
const auto lhs_rank = lhs_shape.rank();
const auto rhs_rank = rhs_shape.rank();

CHECK_EQ(num_spatial_dims + 2, lhs_rank);
CHECK_EQ(num_spatial_dims + 2, rhs_rank);
@ -1175,8 +1175,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {

const auto& dnums = dot->dot_dimension_numbers();

const int64 lhs_rank = ShapeUtil::Rank(lhs->shape());
const int64 rhs_rank = ShapeUtil::Rank(rhs->shape());
const int64 lhs_rank = lhs->shape().rank();
const int64 rhs_rank = rhs->shape().rank();

CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape()));
CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape()));
@ -1238,8 +1238,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {

const auto& dnums = dot->dot_dimension_numbers();

const auto lhs_rank = ShapeUtil::Rank(lhs->shape());
const auto rhs_rank = ShapeUtil::Rank(rhs->shape());
const auto lhs_rank = lhs->shape().rank();
const auto rhs_rank = rhs->shape().rank();

CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape()));
CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape()));
@ -1329,7 +1329,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
CHECK(ShapeUtil::IsArray(pad->operand(0)->shape()));
// Padding value must be scalar.
CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape()));
CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()),
CHECK_EQ(pad->operand(0)->shape().rank(),
pad->padding_config().dimensions_size());

TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
@ -1352,9 +1352,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& evaluated_operand =
parent_->GetEvaluatedLiteralFor(pad->operand(0));

std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()),
0);
std::vector<int64> target_index(ShapeUtil::Rank(result.shape()), 0);
std::vector<int64> input_index(evaluated_operand.shape().rank(), 0);
std::vector<int64> target_index(result.shape().rank(), 0);

// Loop through each element of the operand, assign them to the
// corresponding index of the resulting padded literal.
@ -1609,7 +1608,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys);
int64 sort_dim = sort->dimensions(0);
int64 sort_dim_elements = keys->shape().dimensions(sort_dim);
int64 rank = ShapeUtil::Rank(keys->shape());
int64 rank = keys->shape().rank();
if (rank == 0) {
// Nothing to sort.
parent_->evaluated_[sort] = keys_literal.Clone();
@ -1868,7 +1867,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source);

int64 rank = ShapeUtil::Rank(operand_literal.shape());
int64 rank = operand_literal.shape().rank();

HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
DimensionVector source_index(rank, 0);
@ -1980,7 +1979,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
operand->shape().element_type(), window_dimension_sizes);

DimensionVector window_index(window.dimensions_size());
DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape()));
DimensionVector operand_index(operand_literal.shape().rank());

HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
Literal result(reduce_window->shape());
@ -2411,7 +2410,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
<< " but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);

const int64 rank = ShapeUtil::Rank(operand->shape());
const int64 rank = operand->shape().rank();
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
auto func = [&](absl::Span<const int64> out_index) {
DimensionVector operand_index(rank);
@ -2648,12 +2647,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
auto result = LiteralUtil::CreateR1<NativeT>(data);

if (ShapeUtil::Rank(iota->shape()) > 1) {
if (iota->shape().rank() > 1) {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[iota],
result.Broadcast(iota->shape(), {iota->iota_dimension()}));
} else {
TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1);
TF_RET_CHECK(iota->shape().rank() == 1);
parent_->evaluated_[iota] = std::move(result);
}

@ -2683,7 +2682,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
//
// This lets you calculate LI given the multidimensional indices in any order.
static DimensionVector MakeDimMultipliers(const Shape& shape) {
DimensionVector v(ShapeUtil::Rank(shape));
DimensionVector v(shape.rank());
int64 scale = 1;
for (auto dim : LayoutUtil::MinorToMajor(shape)) {
v[dim] = scale;
@ -2700,7 +2699,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Shape& window_shape, const Window& window, const Shape& base_shape,
const absl::Span<const int64>& window_count_index,
const std::function<void(const std::vector<int64>&)>& f) {
const int64 rank = ShapeUtil::Rank(base_shape);
const int64 rank = base_shape.rank();
DimensionVector window_index(rank);
std::fill(window_index.begin(), window_index.end(), 0);
do {
@ -2767,7 +2766,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const Literal& start_indices_literal) {
auto result = operand_literal.Clone();
auto start_indices_typed = start_indices_literal.data<IndexT>();
const auto rank = ShapeUtil::Rank(result.shape());
const auto rank = result.shape().rank();
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
// Clamp the update start indices so the slice is in-bounds w.r.t the

@ -1039,7 +1039,7 @@ HloInstruction::CreateBroadcastSequence(
const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
adder) {
CHECK(ShapeUtil::IsScalar(operand->shape()) ||
ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape));
operand->shape().rank() == output_shape.rank());
Shape broadcast_shape = ShapeUtil::ChangeElementType(
output_shape, operand->shape().element_type());
// Do explicit broadcast for scalar.
@ -1055,7 +1055,7 @@ HloInstruction::CreateBroadcastSequence(
// Do explicit broadcast for degenerate broadcast.
std::vector<int64> broadcast_dimensions;
std::vector<int64> reshaped_dimensions;
for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) {
for (int i = 0; i < operand->shape().rank(); i++) {
if (operand->shape().dimensions(i) == output_shape.dimensions(i)) {
broadcast_dimensions.push_back(i);
reshaped_dimensions.push_back(operand->shape().dimensions(i));

@ -734,7 +734,7 @@ HloMapInstruction::HloMapInstruction(const Shape& shape,
AppendComputation(map_computation);
// TODO(b/65689298) Remove code below once Map is generalized to accept
// arbitrary map dimensions.
dimensions_.resize(ShapeUtil::Rank(shape));
dimensions_.resize(shape.rank());
std::iota(dimensions_.begin(), dimensions_.end(), 0);
}

@ -1980,7 +1980,7 @@ bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
}

bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) {
const tensorflow::int64 rank = ShapeUtil::Rank(shape);
const tensorflow::int64 rank = shape.rank();
// Create a literal with the given shape in default layout.
*literal = LiteralUtil::CreateFromDimensions(
shape.element_type(), AsInt64Slice(shape.dimensions()));
@ -2145,7 +2145,7 @@ template <typename LiteralNativeT>
bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) {
std::vector<tensorflow::int64> index;

tensorflow::int64 rank = ShapeUtil::Rank(shape);
tensorflow::int64 rank = shape.rank();

*literal = Literal(shape);

@ -30,7 +30,7 @@ HloSharding HloSharding::AssignDevice(int64 device_id) {
}

HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) {
CHECK_EQ(1, ShapeUtil::Rank(input_shape));
CHECK_EQ(1, input_shape.rank());
CHECK_GT(num_tiles, 1);
std::vector<int64> dimensions(1, num_tiles);
Array<int64> assignment(dimensions);
@ -340,7 +340,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
}

// The tile assignment tensor must have the same rank as the input.
if (ShapeUtil::Rank(shape) != tile_assignment_.num_dimensions()) {
if (shape.rank() != tile_assignment_.num_dimensions()) {
return tensorflow::errors::InvalidArgument(
"Number of tile assignment dimensions is different to the input rank. "
"sharding=",

@ -349,7 +349,7 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
Status ShapeVerifier::HandleIota(HloInstruction* instruction) {
TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0));
auto* iota = Cast<HloIotaInstruction>(instruction);
const int64 rank = ShapeUtil::Rank(iota->shape());
const int64 rank = iota->shape().rank();
if (rank == 0) {
return InternalError("Iota does not support scalars.");
}
@ -397,13 +397,11 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
const Shape& operand_shape = broadcast->operand(0)->shape();
// Check for mixed precision.
TF_RET_CHECK(SameElementType(broadcast->shape(), operand_shape));
TF_RET_CHECK(ShapeUtil::Rank(operand_shape) ==
broadcast->dimensions().size());
for (int64 operand_dimension = 0;
operand_dimension < ShapeUtil::Rank(operand_shape);
TF_RET_CHECK(operand_shape.rank() == broadcast->dimensions().size());
for (int64 operand_dimension = 0; operand_dimension < operand_shape.rank();
++operand_dimension) {
int64 output_dimension = broadcast->dimensions()[operand_dimension];
TF_RET_CHECK((output_dimension < ShapeUtil::Rank(broadcast->shape())) &&
TF_RET_CHECK((output_dimension < broadcast->shape().rank()) &&
output_dimension >= 0 &&
(broadcast->shape().dimensions(output_dimension) ==
operand_shape.dimensions(operand_dimension)))
@ -524,8 +522,7 @@ Status ShapeVerifier::HandleMap(HloInstruction* map) {
int64 max_operand_rank = 0;
for (const HloInstruction* operand : map->operands()) {
operand_shapes.push_back(&operand->shape());
max_operand_rank =
std::max(max_operand_rank, ShapeUtil::Rank(operand->shape()));
max_operand_rank = std::max(max_operand_rank, operand->shape().rank());
}
// TODO(b/65689298) Remove code below once Map is generalized to accept
// arbitrary map dimensions.
@ -1271,11 +1268,11 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
// op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I
// or ComputationLowerer::Visit()
TF_RET_CHECK(broadcast->dimensions().size() ==
ShapeUtil::Rank(broadcast->operand(0)->shape()))
broadcast->operand(0)->shape().rank())
<< "Broadcast HLO (" << broadcast->ToShortString()
<< ") has invalid number of dimensions: "
<< broadcast->dimensions().size()
<< " != " << ShapeUtil::Rank(broadcast->operand(0)->shape());
<< " != " << broadcast->operand(0)->shape().rank();
return Status::OK();
}

@ -1376,7 +1373,7 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
for (HloInstruction* operand : instruction->operands()) {
const Shape& operand_shape = operand->shape();
if (LayoutUtil::IsDenseArray(operand_shape) &&
ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(result_shape)) {
operand_shape.rank() == result_shape.rank()) {
const Layout& operand_layout = operand_shape.layout();
TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout))
<< "Instruction shouldn't change layouts "

@ -1002,7 +1002,7 @@ bool CanFoldDotIntoIndexedArray(
absl::Span<const int64> contracting_dims,
absl::Span<const int64> batch_dims) {
absl::optional<int64> non_contracting_non_batch_dim =
GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()),
GetOnlyNonContractingNonBatchDim(indexed_array->shape().rank(),
contracting_dims, batch_dims);
if (!non_contracting_non_batch_dim.has_value()) {
VLOG(3) << tag << ": multiple or no non-contracting non-batch dimensions";
@ -1015,7 +1015,7 @@ bool CanFoldDotIntoIndexedArray(
return false;
}

int64 indexed_array_rank = ShapeUtil::Rank(indexed_array->shape());
int64 indexed_array_rank = indexed_array->shape().rank();
if (indexed_array->source_dim() < (indexed_array_rank - 2)) {
// This restriction can be lifted by inserting reshape nodes.
VLOG(3) << tag
@ -1043,7 +1043,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
return nullptr;
}

int64 lhs_rank = ShapeUtil::Rank(lhs->shape());
int64 lhs_rank = lhs->shape().rank();
DotDimensionNumbers new_dim_numbers = dim_numbers;
new_dim_numbers.set_lhs_contracting_dimensions(
0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1));
@ -1078,7 +1078,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
return nullptr;
}

int64 rhs_rank = ShapeUtil::Rank(rhs->shape());
int64 rhs_rank = rhs->shape().rank();

DotDimensionNumbers new_dim_numbers = dim_numbers;
new_dim_numbers.set_rhs_contracting_dimensions(

@ -991,8 +991,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
CHECK(ShapeUtil::IsArray(instruction->shape()));
CHECK(ShapeUtil::IsArray(operand->shape()));
if (!ShapeUtil::IsScalar(operand->shape()) &&
ShapeUtil::Rank(operand->shape()) ==
ShapeUtil::Rank(instruction->shape()) &&
operand->shape().rank() == instruction->shape().rank() &&
!instruction_can_change_layout_func_(instruction)) {
// Propagate the result layout to the operand layout if the instruction
// requires the same layout out for the result and the operand.
@ -1012,7 +1011,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
// operations. For similar reasons, if the operand and output have the same
// rank, try to match the operand's layout to the output.
if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
ShapeUtil::Rank(instruction->shape()) == 1) {
instruction->shape().rank() == 1) {
// Don't assign a layout in case of R1 -> effective R1 reshape.
return nullptr;
}
@ -1026,7 +1025,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) {
return absl::make_unique<Layout>(operand_shape.layout());
}
if (ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)) {
if (operand_shape.rank() == output_shape.rank()) {
*operand_shape.mutable_layout() = output_layout;
if (ShapeUtil::ReshapeIsBitcast(operand_shape,
output_shape_with_layout)) {
@ -1045,7 +1044,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(

if (instruction->opcode() == HloOpcode::kTranspose) {
// Pick the operand layout that makes the transpose a bitcast.
int64 rank = ShapeUtil::Rank(instruction->shape());
int64 rank = instruction->shape().rank();
std::vector<int64> new_minor_to_major(rank);
for (int64 i = 0; i < rank; ++i) {
int64 output_dim = LayoutUtil::Minor(output_layout, i);
@ -1070,7 +1069,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
ShapeUtil::IsArray(operand->shape()));

if (!ShapeUtil::IsScalar(operand->shape()) &&
ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape()) &&
operand->shape().rank() == user->shape().rank() &&
!instruction_can_change_layout_func_(user)) {
// Assign users the same layout as the operand.
return absl::make_unique<Layout>(operand_layout);
@ -1083,7 +1082,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
// reshape is a bitcast when using the same layout. This may avoid copy
// operations. For similar reasons, if the operand and output have the same
// rank, try to match the outputs's layout to the operand.
if (ShapeUtil::Rank(operand->shape()) == 1 &&
if (operand->shape().rank() == 1 &&
ShapeUtil::TrueRank(user->shape()) == 1) {
// Don't assign a layout in case of R1 -> effective R1 reshape.
return nullptr;
@ -1098,7 +1097,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) {
return absl::make_unique<Layout>(output_shape.layout());
}
if (ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)) {
if (operand->shape().rank() == output_shape.rank()) {
*output_shape.mutable_layout() = operand_layout;
if (ShapeUtil::ReshapeIsBitcast(output_shape,
operand_shape_with_layout)) {
@ -1117,7 +1116,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(

if (user->opcode() == HloOpcode::kTranspose) {
// Pick the user layout that makes the transpose a bitcast.
int64 rank = ShapeUtil::Rank(user->shape());
int64 rank = user->shape().rank();
std::vector<int64> new_minor_to_major(rank);
auto inverse_dimensions = InversePermutation(user->dimensions());
for (int64 i = 0; i < rank; ++i) {
@ -1273,7 +1272,7 @@ Status LayoutAssignment::PropagateOperandConstraint(
return Status::OK();
}

int64 operand_rank = ShapeUtil::Rank(operand->shape());
int64 operand_rank = operand->shape().rank();
if (operand_rank <= 1) {
return Status::OK();
}
@ -1288,7 +1287,7 @@ Status LayoutAssignment::PropagateOperandConstraint(
continue;
}
const HloInstruction* sibling = user->operand(operand_no);
const int64 sibling_rank = ShapeUtil::Rank(sibling->shape());
const int64 sibling_rank = sibling->shape().rank();
if (sibling_rank <= 1) {
continue;
}
@ -1320,13 +1319,13 @@ Status LayoutAssignment::PropagateOperandConstraint(
if (ShapeUtil::IsTuple(subshape)) {
return Status::OK();
}
if (ShapeUtil::Rank(subshape) <= 1) {
if (subshape.rank() <= 1) {
return Status::OK();
}

// Assign the right layout to input fusion of higher rank reduce
// operations.
if (ShapeUtil::Rank(subshape) != ShapeUtil::Rank(operand->shape())) {
if (subshape.rank() != operand->shape().rank()) {
return Status::OK();
}
// TODO(b/67641796): Are there cases except fusion that use this code
@ -1357,7 +1356,7 @@ Status LayoutAssignment::PropagateOperandConstraint(
if (ShapeUtil::IsTuple(subshape)) {
return Status::OK();
}
if (ShapeUtil::Rank(subshape) <= 1) {
if (subshape.rank() <= 1) {
return Status::OK();
}
TF_ASSIGN_OR_RETURN(
@ -1402,7 +1401,7 @@ Status LayoutAssignment::PropagateBufferConstraintToOperands(
if (!instruction_can_change_layout_func_(instruction)) {
// Copy the layout to the operand.
if (buffer.IsArray() && ShapeUtil::IsArray(operand->shape()) &&
ShapeUtil::Rank(operand->shape()) ==
operand->shape().rank() ==
LayoutUtil::MinorToMajor(buffer_constraint.layout()).size()) {
TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
buffer_constraint.layout(), instruction, operand_no,
@ -2101,7 +2100,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
/* static */
bool LayoutAssignment::IsAtMostRank1(const Shape& shape) {
if (ShapeUtil::IsArray(shape)) {
return ShapeUtil::Rank(shape) <= 1;
return shape.rank() <= 1;
}
return absl::c_all_of(shape.tuple_shapes(), [](const Shape& subshape) {
return IsAtMostRank1(subshape);

@ -528,8 +528,7 @@ class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment {
for (int64 operand_no = 0; operand_no < instruction->operand_count();
++operand_no) {
const HloInstruction* operand = instruction->operand(operand_no);
if (ShapeUtil::Rank(instruction->shape()) !=
ShapeUtil::Rank(operand->shape())) {
if (instruction->shape().rank() != operand->shape().rank()) {
continue;
}
TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(

@ -44,7 +44,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl(
const Shape& output_shape = output_array.GetShape();

// Read start indices from start_indices_generator.
const int64 rank = ShapeUtil::Rank(output_shape);
const int64 rank = output_shape.rank();
IrArray::Index start_index(b->getInt64Ty(), rank);
for (int64 i = 0; i < rank; ++i) {
IrArray::Index dim_index({b->getInt64(i)});

@ -61,7 +61,7 @@ void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,

IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
llvm::IRBuilder<>* b)
: multidim_(ShapeUtil::Rank(shape)),
: multidim_(shape.rank()),
linear_(linear),
layout_(shape.layout()),
dims_(shape.dimensions().begin(), shape.dimensions().end()) {
@ -120,7 +120,7 @@ IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape)
if (!ShapeUtil::IsArray(*shape_) || ShapeUtil::IsScalar(*shape_)) {
DCHECK(depth == 1 || depth == 0) << depth;
} else {
DCHECK_EQ(depth, ShapeUtil::Rank(*shape_)) << shape.ShortDebugString();
DCHECK_EQ(depth, shape_->rank()) << shape.ShortDebugString();
}
}

@ -137,12 +137,12 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
const Shape& output_shape, const Shape& input_shape,
llvm::IRBuilder<>* builder) const {
const auto& target_index = *this;
CHECK_EQ(target_index.size(), ShapeUtil::Rank(output_shape));
CHECK_EQ(target_index.size(), output_shape.rank());
std::vector<std::pair<int64, int64>> common_factors =
CommonFactors(AsInt64Slice(input_shape.dimensions()),
AsInt64Slice(output_shape.dimensions()));
std::vector<llvm::Value*> source_multidim_index(
ShapeUtil::Rank(input_shape), llvm::UndefValue::get(index_type_));
input_shape.rank(), llvm::UndefValue::get(index_type_));
// We compute the source indices in each common factor from only the target
// indices in the same common factor.
for (ssize_t k = common_factors.size() - 2; k >= 0; --k) {
@ -257,7 +257,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
const Shape& shape, const Shape& operand_shape,
absl::Span<const int64> dimension_mapping,
llvm::IRBuilder<>* builder) const {
int64 rank = ShapeUtil::Rank(operand_shape);
int64 rank = operand_shape.rank();
std::vector<llvm::Value*> source_index(rank);
for (int64 i = 0; i < rank; ++i) {
source_index[i] = multidim_[dimension_mapping[i]];
@ -271,7 +271,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
// The other dimensions can be masked out with a div and a mod operation.
std::vector<int64> logical_to_physical =
LayoutUtil::MakeLogicalToPhysical(shape.layout());
int64 output_rank = ShapeUtil::Rank(shape);
int64 output_rank = shape.rank();
// The minimum physical dimension that is broadcasted.
int64 min_broadcasted_dimension = output_rank;
// The maximum physical dimension that is broadcasted.
@ -348,7 +348,7 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
// over higher-rank arrays.
return base_ptr_;
}
CHECK_EQ(index.size(), ShapeUtil::Rank(*shape_));
CHECK_EQ(index.size(), shape_->rank());

if (index.LinearValidOnShape(*shape_)) {
llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();

@ -235,7 +235,7 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,

IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
absl::string_view suffix) {
std::vector<int64> dimensions(ShapeUtil::Rank(shape));
std::vector<int64> dimensions(shape.rank());
std::iota(dimensions.begin(), dimensions.end(), 0);
return AddLoopsForShapeOnDimensions(shape, dimensions, suffix);
}

@ -322,7 +322,7 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
// comparisons).

const Shape& keys_shape = keys_array.GetShape();
int64 rank = ShapeUtil::Rank(keys_shape);
int64 rank = keys_shape.rank();
int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
std::vector<int64> dimensions_in_iteration_order(rank);
std::vector<int64> iteration_order_to_logical_order(rank);

@ -831,7 +831,7 @@ class ShapePatternRankImpl {
explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {}

bool Match(const ::xla::Shape* shape, MatchOption option) const {
if (ShapeUtil::Rank(*shape) != rank_) {
if (shape->rank() != rank_) {
if (rank_ == 0) {
EXPLAIN << "Shape is not a scalar";
} else {

@ -88,7 +88,7 @@ static StatusOr<HloInstruction*> CanonicalizeScatterIndices(
static StatusOr<HloInstruction*> PermuteScatterAndWindowDims(
HloInstruction* updates, absl::Span<const int64> update_window_dims) {
std::vector<int64> permutation;
const int64 updates_rank = ShapeUtil::Rank(updates->shape());
const int64 updates_rank = updates->shape().rank();
permutation.reserve(updates_rank);

for (int64 i = 0; i < updates_rank; ++i) {

@ -96,7 +96,7 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape,
}

for (const Shape* element_shape : accumulator_subshapes) {
if (ShapeUtil::Rank(*element_shape) != 0) {
if (element_shape->rank() != 0) {
return InvalidArgument(
"Reduction function must return a scalar or tuple of scalars but "
"returns shape: %s",
@ -160,10 +160,10 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
const Window& window,
PrimitiveType element_type,
bool allow_negative_padding) {
if (window.dimensions_size() != ShapeUtil::Rank(base_shape)) {
if (window.dimensions_size() != base_shape.rank()) {
return InvalidArgument(
"Window has dimension %d but base shape has dimension %d.",
window.dimensions_size(), ShapeUtil::Rank(base_shape));
window.dimensions_size(), base_shape.rank());
}

std::vector<int64> output_dimensions(window.dimensions_size());
@ -338,7 +338,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
if (arg_shapes.empty()) {
return InvalidArgument("Concatenate expects at least one argument.");
}
if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) {
if (dimension < 0 || dimension >= arg_shapes[0]->rank()) {
return InvalidArgument("Concatenate dimension out of bounds: %d.",
dimension);
}
@ -351,12 +351,12 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
element_type = arg_shape->element_type();
continue;
}
if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) {
if (arg_shape->rank() != shape->rank()) {
return InvalidArgument(
"Cannot concatenate arrays with different ranks: %d (%s) vs %d "
"(%s).",
ShapeUtil::Rank(*arg_shape), ShapeUtil::HumanString(*arg_shape),
ShapeUtil::Rank(*shape), ShapeUtil::HumanString(*shape));
arg_shape->rank(), ShapeUtil::HumanString(*arg_shape), shape->rank(),
ShapeUtil::HumanString(*shape));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) {
return InvalidArgument(
@ -364,8 +364,8 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
PrimitiveType_Name(arg_shape->element_type()),
PrimitiveType_Name(shape->element_type()));
}
for (int64 dimension_number = 0;
dimension_number < ShapeUtil::Rank(*arg_shape); ++dimension_number) {
for (int64 dimension_number = 0; dimension_number < arg_shape->rank();
++dimension_number) {
if (arg_shape->dimensions(dimension_number) !=
shape->dimensions(dimension_number)) {
if (dimension_number == dimension) {
@ -480,7 +480,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return InvalidArgument(
"Pad operation does not support non-scalar padding values.");
}
if (ShapeUtil::Rank(operand_shape) != padding_config.dimensions_size()) {
if (operand_shape.rank() != padding_config.dimensions_size()) {
return InvalidArgument(
"The rank of the operand and the padding configuration do not match: "
"%s vs %s.",
@ -500,7 +500,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
padding_config.ShortDebugString());
}

std::vector<int64> dimensions(ShapeUtil::Rank(operand_shape));
std::vector<int64> dimensions(operand_shape.rank());
for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
const auto& p = padding_config.dimensions(i);
dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() +
@ -555,9 +555,9 @@ Status ValidateDotDimensionNumbers(
absl::Span<const int64> rhs_batch_dimensions =
AsInt64Slice(dimension_numbers.rhs_batch_dimensions());

if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions,
if (!dims_in_range(lhs.rank(), lhs_contracting_dimensions,
lhs_batch_dimensions) ||
!dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions,
!dims_in_range(rhs.rank(), rhs_contracting_dimensions,
rhs_batch_dimensions)) {
return InvalidArgument("A dimension number is out of range in Dot: %s.",
dimension_numbers.DebugString());
@ -583,12 +583,10 @@ Status ValidateDotDimensionNumbers(

// Check that the count of non-contracting-non-batch dimensions is in {0, 1}.
const int64 lhs_non_contracting_non_batch_dims =
ShapeUtil::Rank(lhs) -
dimension_numbers.lhs_contracting_dimensions_size() -
lhs.rank() - dimension_numbers.lhs_contracting_dimensions_size() -
dimension_numbers.lhs_batch_dimensions_size();
const int64 rhs_non_contracting_non_batch_dims =
ShapeUtil::Rank(rhs) -
dimension_numbers.rhs_contracting_dimensions_size() -
rhs.rank() - dimension_numbers.rhs_contracting_dimensions_size() -
dimension_numbers.rhs_batch_dimensions_size();
if (lhs_non_contracting_non_batch_dims < 0 ||
lhs_non_contracting_non_batch_dims > 1 ||
@ -637,7 +635,7 @@ Status ValidateDotDimensionNumbers(
return fail("Element types do not match.");
}

if ((ShapeUtil::Rank(lhs) < 1) || (ShapeUtil::Rank(rhs) < 1)) {
if ((lhs.rank() < 1) || (rhs.rank() < 1)) {
return fail("Dot only supports rank 1 or above.");
}

@ -686,12 +684,12 @@ Status ValidateDotDimensionNumbers(
std::unordered_set<int64> rhs_batch_dims(
dimension_numbers.rhs_batch_dimensions().begin(),
dimension_numbers.rhs_batch_dimensions().end());
for (int64 i = 0; i < ShapeUtil::Rank(lhs); i++) {
for (int64 i = 0; i < lhs.rank(); i++) {
if (i != lhs_contracting_dimension) {
dimensions.push_back(lhs.dimensions(i));
}
}
for (int64 i = 0; i < ShapeUtil::Rank(rhs); i++) {
for (int64 i = 0; i < rhs.rank(); i++) {
if (i != rhs_contracting_dimension && rhs_batch_dims.count(i) == 0) {
dimensions.push_back(rhs.dimensions(i));
}
@ -708,14 +706,14 @@ Status ValidateDotDimensionNumbers(
ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Shape& lhs,
const Shape& rhs) {
TF_RET_CHECK(ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs));
TF_RET_CHECK(lhs.rank() == rhs.rank());

// The shapes have to be compatible. That is, if some dimension d has a
// different size in the two shapes, one of them has to be 1 (a "degenerate"
// dimension). In that case, the output shape has the non-1 dimension size
// from the lhs/rhs pair in every index.
std::vector<int64> output_dimensions(ShapeUtil::Rank(lhs));
for (int64 i = 0; i < ShapeUtil::Rank(lhs); ++i) {
std::vector<int64> output_dimensions(lhs.rank());
for (int64 i = 0; i < lhs.rank(); ++i) {
if (lhs.dimensions(i) == rhs.dimensions(i)) {
output_dimensions[i] = lhs.dimensions(i);
} else if (lhs.dimensions(i) == 1) {
@ -743,13 +741,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument("Automatic shape inference not supported: %s and %s",
ShapeUtil::HumanString(smaller_shape),
ShapeUtil::HumanString(larger_shape));
} else if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) {
} else if (broadcast_dimensions.size() != smaller_shape.rank()) {
return InvalidArgument(
"Size of broadcast_dimensions has to match lower-rank operand's "
"rank; "
" lower-rank operand's rank is %d, size of broadcast_dimensions is "
"%u.",
ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size());
smaller_shape.rank(), broadcast_dimensions.size());
}

// broadcast_dimensions is a sequence of dimensions; its length is equal to
@ -847,8 +845,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
ShapeUtil::HumanString(rhs));
}

if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) {
std::vector<int64> identity_dims(ShapeUtil::Rank(lhs));
if (lhs.rank() == rhs.rank()) {
std::vector<int64> identity_dims(lhs.rank());
std::iota(identity_dims.begin(), identity_dims.end(), 0);
if (!broadcast_dimensions.empty() &&
broadcast_dimensions != identity_dims) {
@ -865,15 +863,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs));
}

if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) {
if (lhs.rank() == rhs.rank()) {
return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
} else {
// Ranks do not match, so perform InDim broadcasting using
// broadcast_dimensions. Scalar broadcasting is a special case of this.
const Shape& larger_shape =
ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? lhs : rhs;
const Shape& smaller_shape =
ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? rhs : lhs;
const Shape& larger_shape = lhs.rank() > rhs.rank() ? lhs : rhs;
const Shape& smaller_shape = lhs.rank() > rhs.rank() ? rhs : lhs;

// After InDim broadcasting, perform degenerate dimensions broadcasting.
TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape,
@ -1162,12 +1158,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) ==
Status::OK());

if (feature_index >= ShapeUtil::Rank(operand_shape)) {
if (feature_index >= operand_shape.rank()) {
return InvalidArgument(
"Expected feature_index of batch-norm-training to be "
"smaller than the rank of operand_shape; "
"got feature_index %d, and rank %d.",
feature_index, ShapeUtil::Rank(operand_shape));
feature_index, operand_shape.rank());
}

if (feature_index < 0) {
@ -1177,25 +1173,25 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
feature_index);
}

if (ShapeUtil::Rank(operand_shape) < 1) {
if (operand_shape.rank() < 1) {
return InvalidArgument(
"Expected the rank of operand to "
"batch-norm-training to be at least 1; got %d.",
ShapeUtil::Rank(operand_shape));
operand_shape.rank());
}

if (ShapeUtil::Rank(offset_shape) != 1) {
if (offset_shape.rank() != 1) {
return InvalidArgument(
"Offset input of batch-norm-training must have"
" rank 1, but has rank %d.",
ShapeUtil::Rank(offset_shape));
offset_shape.rank());
}

if (ShapeUtil::Rank(scale_shape) != 1) {
if (scale_shape.rank() != 1) {
return InvalidArgument(
"Scale input of batch-norm-training must have"
" rank 1, but has rank %d.",
ShapeUtil::Rank(scale_shape));
scale_shape.rank());
}

if (!ShapeUtil::ElementIsFloating(operand_shape)) {
@ -1272,12 +1268,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape) ==
Status::OK());

if (feature_index >= ShapeUtil::Rank(operand_shape)) {
if (feature_index >= operand_shape.rank()) {
return InvalidArgument(
"Expected feature_index of batch-norm-inference to be "
"smaller than the rank of operand_shape; "
"got feature_index %d, and rank %d.",
feature_index, ShapeUtil::Rank(operand_shape));
feature_index, operand_shape.rank());
}

if (feature_index < 0) {
@ -1287,25 +1283,25 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
feature_index);
}

if (ShapeUtil::Rank(operand_shape) < 1) {
if (operand_shape.rank() < 1) {
return InvalidArgument(
"Expected the rank of operand to "
"batch-norm-inference to be at least 1; got %d.",
ShapeUtil::Rank(operand_shape));
operand_shape.rank());
}

if (ShapeUtil::Rank(offset_shape) != 1) {
if (offset_shape.rank() != 1) {
return InvalidArgument(
"Offset input of batch-norm-inference must have"
" rank 1, but has rank %d.",
ShapeUtil::Rank(offset_shape));
offset_shape.rank());
}

if (ShapeUtil::Rank(scale_shape) != 1) {
if (scale_shape.rank() != 1) {
return InvalidArgument(
"Scale input of batch-norm-inference must have"
" rank 1, but has rank %d.",
ShapeUtil::Rank(scale_shape));
scale_shape.rank());
}

if (!ShapeUtil::ElementIsFloating(operand_shape)) {
@ -1417,41 +1413,41 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
TF_RETURN_IF_ERROR(
ShapeUtil::ValidateShapeWithOptionalLayout(output_grad_shape));

if (feature_index >= ShapeUtil::Rank(operand_shape)) {
if (feature_index >= operand_shape.rank()) {
return InvalidArgument(
"Expected feature_index of batch-norm-grad to be "
"smaller than the rank of operand_shape; "
"got feature_index %d, and rank %d.",
feature_index, ShapeUtil::Rank(operand_shape));
feature_index, operand_shape.rank());
}

if (ShapeUtil::Rank(operand_shape) != ShapeUtil::Rank(output_grad_shape)) {
if (operand_shape.rank() != output_grad_shape.rank()) {
return InvalidArgument(
"Expected operand_shape of batch-norm-grad to have the same rank as"
" output_grad_shape; got rank(oprand_shape) %d, and"
" rank(output_grad_shape) %d.",
ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(output_grad_shape));
operand_shape.rank(), output_grad_shape.rank());
}

if (ShapeUtil::Rank(mean_shape) != 1) {
if (mean_shape.rank() != 1) {
return InvalidArgument(
"Mean input of batch-norm-grad must have"
" rank 1, but has rank %d.",
ShapeUtil::Rank(mean_shape));
mean_shape.rank());
}

if (ShapeUtil::Rank(scale_shape) != 1) {
if (scale_shape.rank() != 1) {
return InvalidArgument(
"Scale input of batch-norm-grad must have"
" rank 1, but has rank %d.",
ShapeUtil::Rank(scale_shape));
scale_shape.rank());
}

if (ShapeUtil::Rank(var_shape) != 1) {
if (var_shape.rank() != 1) {
return InvalidArgument(
"Var input of batch-norm-grad must have"
" rank 1, but has rank %d.",
ShapeUtil::Rank(var_shape));
var_shape.rank());
}

if (!ShapeUtil::ElementIsFloating(operand_shape)) {
@ -1538,7 +1534,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}

// Verify operand_shape and output_grad_shape have same bounds.
for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
for (int64 i = 0; i < operand_shape.rank(); ++i) {
if (ShapeUtil::GetDimension(operand_shape, i) !=
ShapeUtil::GetDimension(output_grad_shape, i)) {
return InvalidArgument(
@ -1603,12 +1599,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}

const int num_dims = num_spatial_dims + 2;
if (ShapeUtil::Rank(lhs) != num_dims) {
if (lhs.rank() != num_dims) {
return InvalidArgument(
"The LHS argument to a convolution should have rank %d; lhs: %s.",
num_dims, ShapeUtil::HumanString(lhs));
}
if (ShapeUtil::Rank(rhs) != num_dims) {
if (rhs.rank() != num_dims) {
return InvalidArgument(
"The RHS argument to a convolution should have rank %d; rhs: %s.",
num_dims, ShapeUtil::HumanString(rhs));
@ -1853,12 +1849,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Shape& shape, int64 split_dimension, int64 concat_dimension,
int64 split_count) {
TF_RET_CHECK(split_count > 0);
if (split_dimension >= ShapeUtil::Rank(shape) || split_dimension < 0) {
if (split_dimension >= shape.rank() || split_dimension < 0) {
return InvalidArgument(
"AllToAll split_dimension %d is out-of-bounds in shape %s.",
split_dimension, ShapeUtil::HumanString(shape));
}
if (concat_dimension >= ShapeUtil::Rank(shape) || concat_dimension < 0) {
if (concat_dimension >= shape.rank() || concat_dimension < 0) {
return InvalidArgument(
"AllToAll concat_dimension %d is out-of-bounds in shape %s.",
concat_dimension, ShapeUtil::HumanString(shape));
@ -1932,7 +1928,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
// doesn't matter which one we choose.
const Shape& arg = *reduced_args[0];
for (int64 dimension : dimensions_to_reduce) {
if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) {
if (dimension >= arg.rank() || dimension < 0) {
return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.",
dimension, ShapeUtil::HumanString(arg));
}
@ -1949,7 +1945,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
std::set<int64> dimensions_to_reduce_set(dimensions_to_reduce.begin(),
dimensions_to_reduce.end());
std::vector<int64> new_dimensions;
for (int i = 0; i < ShapeUtil::Rank(arg); ++i) {
for (int i = 0; i < arg.rank(); ++i) {
if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) {
new_dimensions.push_back(arg.dimensions(i));
}
@ -2041,7 +2037,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
|
||||
/* static */ StatusOr<Shape> ShapeInference::InferGetDimensionSizeShape(
|
||||
const Shape& shape, int64 dimension) {
|
||||
if (dimension < 0 || dimension >= ShapeUtil::Rank(shape)) {
|
||||
if (dimension < 0 || dimension >= shape.rank()) {
|
||||
return InvalidArgument("GetDimensionSize dimension out of bounds: %d.",
|
||||
dimension);
|
||||
}
|
||||
@ -2083,10 +2079,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
starts.size(), strides.size()));
|
||||
}
|
||||
|
||||
if (starts.size() != ShapeUtil::Rank(arg)) {
|
||||
if (starts.size() != arg.rank()) {
|
||||
return InvalidArgument(
|
||||
"Slice index count does not match argument rank: %u vs %d.",
|
||||
starts.size(), ShapeUtil::Rank(arg));
|
||||
starts.size(), arg.rank());
|
||||
}
|
||||
|
||||
std::vector<int64> sizes;
|
||||
@ -2132,10 +2128,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
ShapeUtil::HumanString(operand_shape),
|
||||
ShapeUtil::HumanString(start_indices_shape), StrJoin(slice_sizes, ", "));
|
||||
|
||||
if (ShapeUtil::Rank(start_indices_shape) != 1) {
|
||||
if (start_indices_shape.rank() != 1) {
|
||||
return InvalidArgument(
|
||||
"Dynamic slice start indices of rank %d must be rank1.",
|
||||
ShapeUtil::Rank(start_indices_shape));
|
||||
start_indices_shape.rank());
|
||||
}
|
||||
|
||||
if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
|
||||
@ -2144,18 +2140,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
}
|
||||
|
||||
const int64 start_num_dims = start_indices_shape.dimensions(0);
|
||||
if (ShapeUtil::Rank(operand_shape) != start_num_dims) {
|
||||
if (operand_shape.rank() != start_num_dims) {
|
||||
return InvalidArgument(
|
||||
"Dynamic slice start number of dimensions %d (%s) must match rank "
|
||||
"%d of slice input (%s).",
|
||||
start_num_dims, ShapeUtil::HumanString(start_indices_shape),
|
||||
ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape));
|
||||
operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
|
||||
}
|
||||
|
||||
if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) {
|
||||
if (slice_sizes.size() != operand_shape.rank()) {
|
||||
return InvalidArgument(
|
||||
"Dynamic slice index count does not match argument rank: %u vs %d.",
|
||||
slice_sizes.size(), ShapeUtil::Rank(operand_shape));
|
||||
slice_sizes.size(), operand_shape.rank());
|
||||
}
|
||||
|
||||
for (int64 dim = 0; dim < slice_sizes.size(); ++dim) {
|
||||
@ -2193,10 +2189,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
      ShapeUtil::HumanString(start_indices_shape),
      ShapeUtil::HumanString(update_shape));

  if (ShapeUtil::Rank(start_indices_shape) != 1) {
  if (start_indices_shape.rank() != 1) {
    return InvalidArgument(
        "Dynamic update slice start indices of rank %d must be rank1.",
        ShapeUtil::Rank(start_indices_shape));
        start_indices_shape.rank());
  }

  if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
@ -2205,19 +2201,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
  }

  const int64 start_num_dims = start_indices_shape.dimensions(0);
  if (ShapeUtil::Rank(operand_shape) != start_num_dims) {
  if (operand_shape.rank() != start_num_dims) {
    return InvalidArgument(
        "Dynamic update slice start number of dimensions %d (%s) must match "
        "rank %d of slice input (%s).",
        start_num_dims, ShapeUtil::HumanString(start_indices_shape),
        ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape));
        operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
  }

  if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) {
  if (update_shape.rank() != operand_shape.rank()) {
    return InvalidArgument(
        "Dynamic update slice update rank does not match argument rank: "
        "%d vs %d.",
        ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape));
        update_shape.rank(), operand_shape.rank());
  }

  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
@ -2229,7 +2225,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
        PrimitiveType_Name(update_shape.element_type()));
  }

  for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) {
  for (int64 dim = 0; dim < operand_shape.rank(); ++dim) {
    const int64 input_dim_size = operand_shape.dimensions(dim);
    const int64 update_dim_size = update_shape.dimensions(dim);
    if (update_dim_size < 0) {
@ -2255,7 +2251,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
    return InvalidArgument("a dimension number is duplicated in reverse");
  }
  for (int64 dimension : dimensions) {
    if (dimension >= ShapeUtil::Rank(operand_shape) || dimension < 0) {
    if (dimension >= operand_shape.rank() || dimension < 0) {
      return InvalidArgument(
          "One of the reverse dimensions (%d) is out-of-bounds in shape %s.",
          dimension, ShapeUtil::HumanString(operand_shape));
@ -2397,8 +2393,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
    absl::Span<const int64> broadcast_dimensions) {
  TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of broadcast"));
  TF_RETURN_IF_ERROR(ExpectArray(output_shape, "operand of broadcast"));
  const int64 operand_rank = ShapeUtil::Rank(operand_shape);
  const int64 output_rank = ShapeUtil::Rank(output_shape);
  const int64 operand_rank = operand_shape.rank();
  const int64 output_rank = output_shape.rank();
  if (operand_rank > output_rank) {
    return InvalidArgument(
        "InDim style broadcast must be to an equal or higher ranked shape; "
@ -2457,9 +2453,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
        ShapeUtil::HumanString(inferred_shape));
  }

  std::vector<int64> indices(ShapeUtil::Rank(operand));
  std::vector<int64> indices(operand.rank());
  std::iota(indices.begin(), indices.end(), 0);
  if (dimensions.size() != ShapeUtil::Rank(operand) ||
  if (dimensions.size() != operand.rank() ||
      !std::is_permutation(dimensions.begin(), dimensions.end(),
                           indices.begin())) {
    return InvalidArgument(
@ -2475,9 +2471,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
    const Shape& operand, absl::Span<const int64> dimensions) {
  TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));

  std::vector<int64> indices(ShapeUtil::Rank(operand));
  std::vector<int64> indices(operand.rank());
  std::iota(indices.begin(), indices.end(), 0);
  if (dimensions.size() != ShapeUtil::Rank(operand) ||
  if (dimensions.size() != operand.rank() ||
      !std::is_permutation(dimensions.begin(), dimensions.end(),
                           indices.begin())) {
    return InvalidArgument(
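A sketch of the iota/is_permutation test above for a rank-3 operand, with illustrative values not taken from this diff:

    std::vector<int64> indices = {0, 1, 2};  // std::iota over operand.rank()
    // dimensions = {2, 0, 1} -> std::is_permutation(...) is true, accepted.
    // dimensions = {0, 0, 2} -> not a permutation of {0, 1, 2}, rejected.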
@ -2829,7 +2825,7 @@ Status ValidateScatterDimensionNumbers(
        "update_window_dims in scatter op must not repeat; got: %s.",
        StrJoin(dim_numbers.update_window_dims(), ", "));
  }
  const int64 updates_rank = ShapeUtil::Rank(updates_shape);
  const int64 updates_rank = updates_shape.rank();
  for (int64 window_dim : dim_numbers.update_window_dims()) {
    if (window_dim < 0 || window_dim >= updates_rank) {
      return InvalidArgument(
@ -2863,10 +2859,10 @@ Status ValidateScatterDimensionNumbers(
  // Validate window size.
  auto window_size = dim_numbers.update_window_dims_size() +
                     dim_numbers.inserted_window_dims_size();
  if (window_size != ShapeUtil::Rank(operand_shape)) {
  if (window_size != operand_shape.rank()) {
    return InvalidArgument(
        "Scatter op has window of size %d; doesn't match operand of rank %d.",
        window_size, ShapeUtil::Rank(operand_shape));
        window_size, operand_shape.rank());
  }

  // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers.
@ -2951,10 +2947,9 @@ Status ValidateScatterDimensionNumbers(

  int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 +
                                scatter_dim_numbers.update_window_dims_size();
  if (ShapeUtil::Rank(updates_shape) != expected_updates_rank) {
  if (updates_shape.rank() != expected_updates_rank) {
    return InvalidArgument("Updates tensor must be of rank %d; got %d.",
                           expected_updates_rank,
                           ShapeUtil::Rank(updates_shape));
                           expected_updates_rank, updates_shape.rank());
  }

  TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers(
@ -2985,7 +2980,7 @@ Status ValidateScatterDimensionNumbers(
  }

  int64 scatter_dims_seen = 0;
  for (int64 i = 0; i < ShapeUtil::Rank(updates_shape); ++i) {
  for (int64 i = 0; i < updates_shape.rank(); ++i) {
    bool is_update_window_dim =
        absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i);
    if (is_update_window_dim) {
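A worked instance of the window-size check above, with assumed numbers rather than values from this diff:

    // update_window_dims   = {1, 2}  (size 2)
    // inserted_window_dims = {0}     (size 1)
    // window_size          = 2 + 1 = 3, so the operand must be rank 3;
    // any other operand rank returns the InvalidArgument above.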
|
@ -45,7 +45,7 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot(
    auto& operand = *dot.operand(i);
    if (operand.IsRank2Transpose()) {
      operand_set.push_back(i);
    } else if (ShapeUtil::Rank(operand.shape()) != 2) {
    } else if (operand.shape().rank() != 2) {
      return {};
    }
  }
|
@ -61,6 +61,12 @@ string Shape::ToString(bool print_layout) const {
  }
}

int64 Shape::rank() const {
  CHECK(ShapeUtil::IsArray(*this))
      << "Non-arrays do not have a rank, shape: " << *this;
  return dimensions_.size();
}

std::ostream& operator<<(std::ostream& out, const Shape& shape) {
  out << shape.ToString(/*print_layout=*/true);
  return out;
|
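The old and new spellings side by side, as a minimal sketch; ShapeUtil::MakeShape and MakeTupleShape are the usual helpers and are not touched by this CL:

    xla::Shape array = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
    int64 r_old = xla::ShapeUtil::Rank(array);  // deprecated spelling
    int64 r_new = array.rank();                 // preferred; both are 2
    // Note the new accessor CHECK-fails on non-arrays: calling .rank() on
    // xla::ShapeUtil::MakeTupleShape({array}) aborts with
    // "Non-arrays do not have a rank".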
@ -44,6 +44,10 @@ class Shape {
  // without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]".
  string ToString(bool print_layout = false) const;

  // Returns the rank (number of dimensions) of the given shape. Shape must be
  // an array.
  int64 rank() const;

  // The following methods mirror the protobuf generated code interface for the
  // message ShapeProto. This enabled easy migration of this data structure
  // from a proto to a proper C++ class.
|
@ -679,8 +679,8 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
/* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) {
  TF_DCHECK_OK(ValidateShape(shape));
  CHECK(LayoutUtil::IsSparseArray(shape));
  return LayoutUtil::MaxSparseElements(shape.layout()) *
         ShapeUtil::Rank(shape) * sizeof(int64);
  return LayoutUtil::MaxSparseElements(shape.layout()) * shape.rank() *
         sizeof(int64);
}

/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
@ -763,7 +763,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
    return sparse_elements_size;
  }
  int64 sparse_indices_size =
      MultiplyWithoutOverflow(max_sparse_elements, ShapeUtil::Rank(shape));
      MultiplyWithoutOverflow(max_sparse_elements, shape.rank());
  if (sparse_indices_size < 0) {
    return sparse_indices_size;
  }
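A back-of-the-envelope check of the byte-size expression above, with assumed numbers:

    // A sparse rank-3 shape whose layout allows 100 stored elements keeps
    // one int64 coordinate per dimension per element, so
    //   ByteSizeOfSparseIndices == 100 * 3 * sizeof(int64) == 2400 bytes.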
|
@ -298,6 +298,7 @@ class ShapeUtil {

  // Returns the rank (number of dimensions) of the given shape.
  // Precondition: !IsTuple(shape)
  ABSL_DEPRECATED("Use `Shape::rank` instead.")
  static int64 Rank(const Shape& shape);

  // Returns the number of dimensions for which the dimension is not (trivially)
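Presumably the deprecated entry point now just forwards to the new accessor until the follow-up CL deletes it; a sketch, not quoted from this diff:

    /* static */ int64 ShapeUtil::Rank(const Shape& shape) {
      return shape.rank();
    }

Callers that have not migrated keep compiling but pick up a -Wdeprecated-declarations warning on toolchains where ABSL_DEPRECATED expands to the compiler's deprecated attribute.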
|
@ -79,7 +79,7 @@ void SparseIndexArray::Resize(int64 num_indices) {
}

bool SparseIndexArray::Validate(const Shape& shape) const {
  if (rank_ == 0 || rank_ != ShapeUtil::Rank(shape)) {
  if (rank_ == 0 || rank_ != shape.rank()) {
    return false;
  }
  int64 num_indices = index_count();
|
@ -191,7 +191,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
  verify_output(actual, "");

  // Try with all output layouts.
  std::vector<int64> minor_to_major(ShapeUtil::Rank(expected.shape()));
  std::vector<int64> minor_to_major(expected.shape().rank());
  std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
  do {
    auto layout = ShapeUtil::MakeShapeWithLayout(
@ -234,7 +234,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
    return Status::OK();
  }

  std::vector<int64> minor_to_major(ShapeUtil::Rank(literal.shape()));
  std::vector<int64> minor_to_major(literal.shape().rank());
  std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
  do {
    auto literal_relayout =
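The iota/do-while pairing above is the usual enumerate-all-layouts pattern; a sketch under the assumption that the loop is advanced with std::next_permutation (the advancing call is outside this hunk):

    std::vector<int64> minor_to_major(literal.shape().rank());
    std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
    do {
      // Relayout `literal` with this minor_to_major and run the comparison.
    } while (std::next_permutation(minor_to_major.begin(),
                                   minor_to_major.end()));

Starting from the sorted iota sequence guarantees every permutation, i.e. every possible layout of the shape, is visited exactly once.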
|
@ -348,7 +348,7 @@ StatusOr<Literal> CreateLiteralForConstrainedUses(
      const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice
                                     ? use->shape()
                                     : use->operand(1)->shape();
      const int64 rank = ShapeUtil::Rank(indexed_shape);
      const int64 rank = indexed_shape.rank();
      if (!index_space.empty()) {
        TF_RET_CHECK(rank == index_space.size());
        for (int64 i = 0; i < rank; ++i) {
@ -459,8 +459,8 @@ Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
                                                      HloInstruction* lhs,
                                                      HloInstruction* rhs) {
  CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
  CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
  CHECK_EQ(lhs->shape().rank(), 2);
  CHECK_EQ(rhs->shape().rank(), 2);
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);