From 1ed59e52b11cf7e3981ab54aeb722085a07d3923 Mon Sep 17 00:00:00 2001 From: Mark Heffernan Date: Thu, 20 Dec 2018 16:24:50 -0800 Subject: [PATCH] Replace calls to ShapeUtil::Rank with Shape::rank. No functional change. ShapeUtil::Rank is marked as deprecated. A later CL will remove it. PiperOrigin-RevId: 226412117 --- tensorflow/compiler/aot/codegen.cc | 2 +- .../tf2xla/kernels/batchtospace_op.cc | 2 +- .../compiler/tf2xla/kernels/index_ops_cpu.cc | 4 +- .../compiler/tf2xla/kernels/mirror_pad_op.cc | 3 +- .../tf2xla/kernels/spacetobatch_op.cc | 2 +- tensorflow/compiler/tf2xla/lib/scatter.cc | 8 +- tensorflow/compiler/tf2xla/shape_util.cc | 4 +- tensorflow/compiler/tf2xla/xla_helpers.cc | 2 +- tensorflow/compiler/tf2xla/xla_op_kernel.cc | 6 +- .../compiler/xla/client/lib/arithmetic.cc | 2 +- .../compiler/xla/client/lib/cholesky.cc | 4 +- tensorflow/compiler/xla/client/lib/matrix.cc | 10 +- tensorflow/compiler/xla/client/lib/qr.cc | 4 +- tensorflow/compiler/xla/client/lib/slicing.cc | 10 +- .../xla/client/lib/triangular_solve.cc | 8 +- .../compiler/xla/client/sharding_builder.cc | 2 +- tensorflow/compiler/xla/client/xla_builder.cc | 33 ++-- tensorflow/compiler/xla/index_util.cc | 2 +- tensorflow/compiler/xla/layout_util.cc | 14 +- tensorflow/compiler/xla/literal.cc | 46 ++--- tensorflow/compiler/xla/literal.h | 14 +- tensorflow/compiler/xla/literal_comparison.cc | 4 +- .../compiler/xla/python/numpy_bridge.cc | 4 +- tensorflow/compiler/xla/reference_util.cc | 2 +- .../xla/service/algebraic_simplifier.cc | 42 ++-- .../xla/service/algebraic_simplifier_test.cc | 4 +- .../xla/service/batchnorm_expander.cc | 8 +- .../xla/service/cpu/dot_op_emitter.cc | 2 +- .../compiler/xla/service/cpu/ir_emitter.cc | 6 +- .../service/dynamic_dimension_inference.cc | 6 +- .../xla/service/dynamic_parameter_binding.cc | 5 +- .../xla/service/elemental_ir_emitter.cc | 10 +- .../gpu/cudnn_conv_pad_for_tensor_cores.cc | 4 +- .../gpu/cudnn_conv_padding_legalization.cc | 2 +- .../compiler/xla/service/gpu/gemm_thunk.cc | 3 +- .../compiler/xla/service/gpu/gpu_fusible.cc | 6 +- .../xla/service/gpu/gpu_layout_assignment.cc | 8 +- .../xla/service/gpu/ir_emission_utils.cc | 2 +- .../xla/service/gpu/ir_emitter_unnested.cc | 15 +- .../xla/service/hlo_creation_utils.cc | 5 +- .../compiler/xla/service/hlo_evaluator.cc | 10 +- .../xla/service/hlo_evaluator_typed_visitor.h | 37 ++-- .../compiler/xla/service/hlo_instruction.cc | 4 +- .../compiler/xla/service/hlo_instructions.cc | 2 +- tensorflow/compiler/xla/service/hlo_parser.cc | 4 +- .../compiler/xla/service/hlo_sharding.cc | 4 +- .../compiler/xla/service/hlo_verifier.cc | 19 +- .../xla/service/indexed_array_analysis.cc | 8 +- .../compiler/xla/service/layout_assignment.cc | 31 ++- .../xla/service/layout_assignment_test.cc | 3 +- .../llvm_ir/dynamic_update_slice_util.cc | 2 +- .../compiler/xla/service/llvm_ir/ir_array.cc | 14 +- .../compiler/xla/service/llvm_ir/llvm_loop.cc | 2 +- .../compiler/xla/service/llvm_ir/sort_util.cc | 2 +- .../compiler/xla/service/pattern_matcher.h | 2 +- .../compiler/xla/service/scatter_expander.cc | 2 +- .../compiler/xla/service/shape_inference.cc | 185 +++++++++--------- .../compiler/xla/service/transpose_folding.cc | 2 +- tensorflow/compiler/xla/shape.cc | 6 + tensorflow/compiler/xla/shape.h | 4 + tensorflow/compiler/xla/shape_util.cc | 6 +- tensorflow/compiler/xla/shape_util.h | 1 + tensorflow/compiler/xla/sparse_index_array.cc | 2 +- .../xla/tests/client_library_test_base.cc | 4 +- tensorflow/compiler/xla/tests/test_utils.cc | 6 +- 65 files changed, 335 insertions(+), 347 deletions(-) diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index ab1c1be344e..347a365dcb9 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -129,7 +129,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type)); std::vector dim_vars; string dim_sizes, indices; - if (xla::ShapeUtil::Rank(shape) == 0 || + if (shape.rank() == 0 || (shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) { dim_sizes = "[1]"; indices = "[0]"; diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 46e5d68c78f..6b675fa8a94 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -39,7 +39,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, OP_REQUIRES( ctx, - xla::ShapeUtil::Rank(crops.shape()) == 2 && + crops.shape().rank() == 2 && block_rank == xla::ShapeUtil::GetDimension(crops.shape(), 0) && 2 == xla::ShapeUtil::GetDimension(crops.shape(), 1), errors::InvalidArgument("crops should have shape [", block_rank, diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index e2c05b648bb..3e7e8eae6ed 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -110,8 +110,8 @@ class ArgMaxCustomCallOp : public XlaOpKernel { auto shape_status = b.GetShape(arg); OP_REQUIRES_OK(ctx, shape_status.status()); xla::Shape arg_shape = shape_status.ConsumeValueOrDie(); - *arg_shape.mutable_layout() = xla::LayoutUtil::MakeDescendingLayout( - xla::ShapeUtil::Rank(arg_shape)); + *arg_shape.mutable_layout() = + xla::LayoutUtil::MakeDescendingLayout(arg_shape.rank()); arg_shapes.push_back(std::move(arg_shape)); } diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index f6b8534f4d7..656f9b898f3 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -38,8 +38,7 @@ class MirrorPadOp : public XlaOpKernel { // - [1, 2, 3, 3, 2] in symmetric mode. int64 excluded_edges = mode == MirrorPadMode::REFLECT ? 1 : 0; xla::XlaOp accum = t; - for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0; - --dimno) { + for (int64 dimno = original_shape.rank() - 1; dimno >= 0; --dimno) { auto t_rev = xla::Rev(accum, {dimno}); int64 lhs_padding = pad_literal.Get({dimno, 0}); int64 rhs_padding = pad_literal.Get({dimno, 1}); diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index 622efac8176..52bed2670b4 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -39,7 +39,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, OP_REQUIRES( ctx, - xla::ShapeUtil::Rank(paddings.shape()) == 2 && + paddings.shape().rank() == 2 && block_rank == xla::ShapeUtil::GetDimension(paddings.shape(), 0) && 2 == xla::ShapeUtil::GetDimension(paddings.shape(), 1), errors::InvalidArgument("paddings should have shape [", block_rank, diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 688056791f9..1cd5a79171d 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -48,7 +48,7 @@ xla::StatusOr XlaScatter( if (indices_are_vectors) { TF_RET_CHECK(!indices_dims.empty()); num_index_dims = indices_dims.back(); - if (num_index_dims > xla::ShapeUtil::Rank(buffer_shape)) { + if (num_index_dims > buffer_shape.rank()) { return errors::InvalidArgument( "The size of the minor dimension of the indices (shape: ", xla::ShapeUtil::HumanString(indices_shape), @@ -140,8 +140,8 @@ xla::StatusOr XlaScatter( ? indices_shape.dimensions_size() - 1 : indices_shape.dimensions_size()); - int64 updates_rank = xla::ShapeUtil::Rank(updates_shape); - int64 buffer_rank = xla::ShapeUtil::Rank(buffer_shape); + int64 updates_rank = updates_shape.rank(); + int64 buffer_rank = buffer_shape.rank(); int64 num_window_dims_in_updates = buffer_rank - num_index_dims; // If the rank of `updates` is 0 and does not match the expected rank of @@ -156,7 +156,7 @@ xla::StatusOr XlaScatter( if (updates_rank == 0 && expected_updates_rank != 0) { new_updates = xla::Broadcast(updates, expected_updates_dims); TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates)); - updates_rank = xla::ShapeUtil::Rank(updates_shape); + updates_rank = updates_shape.rank(); } if (updates_rank > 0) { diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index ec604af1386..2116a735d7d 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -39,7 +39,7 @@ Status PopulateInfeedLayoutVector(const xla::Shape& shape, layouts->push_back(dim); } } else { - layouts->insert(layouts->end(), xla::ShapeUtil::Rank(shape), -1); + layouts->insert(layouts->end(), shape.rank(), -1); } return Status::OK(); } @@ -55,7 +55,7 @@ Status XLAShapeToTensorShape(const xla::Shape& shape, " cannot be converted to a TensorShape"); } *tensor_shape = TensorShape(); - for (int i = 0; i < xla::ShapeUtil::Rank(shape); ++i) { + for (int i = 0; i < shape.rank(); ++i) { tensor_shape->AddDim(shape.dimensions(i)); } return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index c2c07512111..83ba11b3bfd 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -55,7 +55,7 @@ xla::XlaOp ArgMinMax(xla::XlaOp input, xla::PrimitiveType output_type, int axis, xla::XlaOp input_max = xla::Reduce(input, init_value, reducer, /*dimensions_to_reduce=*/{axis}); - std::vector broadcast_dims(xla::ShapeUtil::Rank(input_shape) - 1); + std::vector broadcast_dims(input_shape.rank() - 1); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); // Compute a mask that has 1s for elements equal to the maximum. diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 58808c76de6..58bd173e61a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -178,7 +178,7 @@ Status XlaOpKernelContext::ConstantInputReshaped( // Converts an int32 or int64 scalar literal to an int64. static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, int64* out) { - if (xla::ShapeUtil::Rank(literal.shape()) != 0) { + if (literal.shape().rank() != 0) { return errors::InvalidArgument("value is not a scalar"); } if (literal.shape().element_type() == xla::S32) { @@ -194,7 +194,7 @@ static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, // Converts an float32 or float64 scalar literal to a float64. static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal, double* out) { - if (xla::ShapeUtil::Rank(literal.shape()) != 0) { + if (literal.shape().rank() != 0) { return errors::InvalidArgument("value is not a scalar"); } if (literal.shape().element_type() == xla::F32) { @@ -228,7 +228,7 @@ Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) { // Converts an int32 or int64 1D literal to an int64 vector. static Status LiteralToInt64Vector(const xla::LiteralSlice& literal, std::vector* out) { - if (xla::ShapeUtil::Rank(literal.shape()) != 1) { + if (literal.shape().rank() != 1) { return errors::InvalidArgument("value is not 1D"); } int64 size = xla::ShapeUtil::ElementsIn(literal.shape()); diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index e86c10f030f..33ff3971d72 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -117,7 +117,7 @@ XlaOp Any(XlaOp predicates) { XlaComputation logical_or = CreateScalarOrComputation(PRED, builder); TF_ASSIGN_OR_RETURN(const Shape& predicates_shape, builder->GetShape(predicates)); - std::vector all_dimensions(ShapeUtil::Rank(predicates_shape)); + std::vector all_dimensions(predicates_shape.rank()); std::iota(all_dimensions.begin(), all_dimensions.end(), 0); return Reduce(predicates, f, logical_or, all_dimensions); }); diff --git a/tensorflow/compiler/xla/client/lib/cholesky.cc b/tensorflow/compiler/xla/client/lib/cholesky.cc index fd980499684..83b83198799 100644 --- a/tensorflow/compiler/xla/client/lib/cholesky.cc +++ b/tensorflow/compiler/xla/client/lib/cholesky.cc @@ -54,7 +54,7 @@ XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - const int n_dims = ShapeUtil::Rank(a_shape); + const int n_dims = a_shape.rank(); const int64 n = ShapeUtil::GetDimension(a_shape, -1); auto major_dims = AsInt64Slice(a_shape.dimensions()) .subspan( @@ -144,7 +144,7 @@ XlaOp Cholesky(XlaOp a, int64 block_size, XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - const int ndims = ShapeUtil::Rank(a_shape); + const int ndims = a_shape.rank(); if (ndims < 2) { return InvalidArgument( "Argument to Cholesky must have rank >= 2; shape was %s", diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc index 16c177b4e22..2a1e832dc26 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.cc +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -41,7 +41,7 @@ XlaOp GetMatrixDiagonal(XlaOp x) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_dims = shape.rank(); TF_RET_CHECK(n_dims >= 2); const int64 m = shape.dimensions(n_dims - 2); const int64 n = shape.dimensions(n_dims - 1); @@ -68,7 +68,7 @@ XlaOp TriangleMask(XlaOp x, int diagonal) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_dims = shape.rank(); TF_RET_CHECK(n_dims >= 2); const int64 m = shape.dimensions(n_dims - 2); const int64 n = shape.dimensions(n_dims - 1); @@ -99,12 +99,12 @@ XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) { // Check that both tensors have the same number of dimensions. There must be // at least two (the batch dimensions can be empty). - if (ShapeUtil::Rank(x_shape) != ShapeUtil::Rank(y_shape)) { + if (x_shape.rank() != y_shape.rank()) { return InvalidArgument( "Arguments to BatchDot have different ranks: %s vs. %s", ShapeUtil::HumanString(x_shape), ShapeUtil::HumanString(y_shape)); } - const int ndims = ShapeUtil::Rank(x_shape); + const int ndims = x_shape.rank(); if (ndims < 2) { return InvalidArgument( "Arguments to BatchDot must have rank >= 2: got %d", ndims); @@ -169,7 +169,7 @@ XlaOp TransposeInMinorDims(XlaOp x) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_dims = shape.rank(); TF_RET_CHECK(n_dims >= 2); std::vector permutation(n_dims); std::iota(permutation.begin(), permutation.end(), 0); diff --git a/tensorflow/compiler/xla/client/lib/qr.cc b/tensorflow/compiler/xla/client/lib/qr.cc index 72ca653173b..640412ec8bc 100644 --- a/tensorflow/compiler/xla/client/lib/qr.cc +++ b/tensorflow/compiler/xla/client/lib/qr.cc @@ -154,7 +154,7 @@ struct QRBlockResult { StatusOr QRBlock(XlaOp a, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - const int num_dims = ShapeUtil::Rank(a_shape); + const int num_dims = a_shape.rank(); if (num_dims < 2) { return InvalidArgument("Argument to QR must have rank >= 2; got shape %s", a_shape.ToString()); @@ -325,7 +325,7 @@ StatusOr QRDecomposition( PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - const int num_dims = ShapeUtil::Rank(a_shape); + const int num_dims = a_shape.rank(); if (num_dims < 2) { return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s", a_shape.ToString()); diff --git a/tensorflow/compiler/xla/client/lib/slicing.cc b/tensorflow/compiler/xla/client/lib/slicing.cc index f8c7df3ff51..611fffba8d0 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.cc +++ b/tensorflow/compiler/xla/client/lib/slicing.cc @@ -26,7 +26,7 @@ XlaOp SliceInMinorDims(XlaOp x, absl::Span start, TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_dims = shape.rank(); TF_RET_CHECK(n_minor_dims <= n_dims); auto major_dims = AsInt64Slice(shape.dimensions()) .subspan( @@ -55,7 +55,7 @@ XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start) { std::vector start_as_int32(start.begin(), start.end()); auto start_constant = ConstantR1(builder, start_as_int32); TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_dims = shape.rank(); TF_ASSIGN_OR_RETURN(Shape start_constant_shape, builder->GetShape(start_constant)); const int64 start_length = @@ -70,7 +70,7 @@ XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_dims = shape.rank(); const int64 n_minor_dims = start.size(); TF_RET_CHECK(n_minor_dims <= n_dims); std::vector padded_start(n_dims, 0); @@ -94,7 +94,7 @@ XlaOp PrependZerosInMajorDims(XlaOp x, absl::Span starts) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_dims = shape.rank(); auto zero = Reshape(ConstantR0(builder, 0), {1}); std::vector padded_starts(n_dims, zero); for (int i = 0; i < starts.size(); ++i) { @@ -111,7 +111,7 @@ XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_dims = shape.rank(); int64 n_minor_dims = starts.size(); TF_RET_CHECK(n_minor_dims == sizes.size()); TF_RET_CHECK(n_minor_dims <= n_dims); diff --git a/tensorflow/compiler/xla/client/lib/triangular_solve.cc b/tensorflow/compiler/xla/client/lib/triangular_solve.cc index 4bc2f3d1218..6061e64656e 100644 --- a/tensorflow/compiler/xla/client/lib/triangular_solve.cc +++ b/tensorflow/compiler/xla/client/lib/triangular_solve.cc @@ -38,7 +38,7 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(a)); - int ndims = ShapeUtil::Rank(shape); + int ndims = shape.rank(); int64 n = ShapeUtil::GetDimension(shape, -1); int64 num_blocks = n / block_size; @@ -262,7 +262,7 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, int64 block_size = ShapeUtil::GetDimension(blocks_shape, -1); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - int64 ndims = ShapeUtil::Rank(a_shape); + int64 ndims = a_shape.rank(); int64 n = ShapeUtil::GetDimension(a_shape, -1); int64 num_blocks = n / block_size + (n % block_size != 0); int64 m_dim = (left_side) ? -1 : -2; @@ -356,13 +356,13 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); - if (ShapeUtil::Rank(a_shape) != ShapeUtil::Rank(b_shape)) { + if (a_shape.rank() != b_shape.rank()) { return InvalidArgument( "Arguments to TriangularSolve have shapes with different ranks: " "%s vs. %s", ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape)); } - const int64 ndims = ShapeUtil::Rank(a_shape); + const int64 ndims = a_shape.rank(); if (ndims < 2) { return InvalidArgument( "Arguments to TriangularSolve was rank %d but must have rank >= 2.", diff --git a/tensorflow/compiler/xla/client/sharding_builder.cc b/tensorflow/compiler/xla/client/sharding_builder.cc index fb9ea6ec3fc..b9bff06cbdb 100644 --- a/tensorflow/compiler/xla/client/sharding_builder.cc +++ b/tensorflow/compiler/xla/client/sharding_builder.cc @@ -50,7 +50,7 @@ OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles) { OpSharding result; result.set_type(OpSharding::Type::OpSharding_Type_OTHER); - CHECK_EQ(ShapeUtil::Rank(tile_shape), 1); + CHECK_EQ(tile_shape.rank(), 1); std::vector dimensions(1, num_tiles); *result.mutable_tile_shape() = tile_shape.ToProto(); auto& tile_dimension = diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 622fc158e11..fc37a38813b 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -343,7 +343,7 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); CHECK(ShapeUtil::IsScalar(operand_shape) || - ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)); + operand_shape.rank() == output_shape.rank()); Shape broadcast_shape = ShapeUtil::ChangeElementType(output_shape, operand_shape.element_type()); @@ -355,7 +355,7 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, // Do explicit broadcast for degenerate broadcast. std::vector broadcast_dimensions; std::vector reshaped_dimensions; - for (int i = 0; i < ShapeUtil::Rank(operand_shape); i++) { + for (int i = 0; i < operand_shape.rank(); i++) { if (operand_shape.dimensions(i) == output_shape.dimensions(i)) { broadcast_dimensions.push_back(i); reshaped_dimensions.push_back(operand_shape.dimensions(i)); @@ -398,8 +398,8 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, binop, lhs_shape, rhs_shape, broadcast_dimensions)); *instr.mutable_shape() = shape.ToProto(); - const int64 lhs_rank = ShapeUtil::Rank(lhs_shape); - const int64 rhs_rank = ShapeUtil::Rank(rhs_shape); + const int64 lhs_rank = lhs_shape.rank(); + const int64 rhs_rank = rhs_shape.rank(); XlaOp updated_lhs = lhs; XlaOp updated_rhs = rhs; @@ -413,8 +413,7 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, for (int64 size : shape.dimensions()) { to_size.push_back(size); } - for (int64 from_dim = 0; from_dim < ShapeUtil::Rank(from_shape); - from_dim++) { + for (int64 from_dim = 0; from_dim < from_shape.rank(); from_dim++) { int64 to_dim = broadcast_dimensions[from_dim]; to_size[to_dim] = from_shape.dimensions(from_dim); } @@ -563,10 +562,10 @@ XlaOp XlaBuilder::Broadcast(const XlaOp& operand, // output, so to append dimensions on the left the instruction's dimensions // should just be the n highest dimension numbers of the output shape where // n is the number of input dimensions. - const int64 operand_rank = ShapeUtil::Rank(operand_shape); + const int64 operand_rank = operand_shape.rank(); std::vector dimensions(operand_rank); for (int i = 0; i < operand_rank; ++i) { - dimensions[i] = i + ShapeUtil::Rank(shape) - operand_rank; + dimensions[i] = i + shape.rank() - operand_rank; } return InDimBroadcast(shape, operand, dimensions); }); @@ -639,10 +638,10 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); - std::vector starts(ShapeUtil::Rank(shape), 0); + std::vector starts(shape.rank(), 0); std::vector limits(shape.dimensions().begin(), shape.dimensions().end()); - std::vector strides(ShapeUtil::Rank(shape), 1); + std::vector strides(shape.rank(), 1); starts[dimno] = start_index; limits[dimno] = limit_index; strides[dimno] = stride; @@ -780,7 +779,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ","); std::vector new_sizes; - for (int i = 0; i < ShapeUtil::Rank(original_shape); ++i) { + for (int i = 0; i < original_shape.rank(); ++i) { if (i <= dimensions.front() || i > dimensions.back()) { new_sizes.push_back(original_shape.dimensions(i)); } else { @@ -915,13 +914,13 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, Status XlaBuilder::VerifyConvolution( const Shape& lhs_shape, const Shape& rhs_shape, const ConvolutionDimensionNumbers& dimension_numbers) const { - if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) { + if (lhs_shape.rank() != rhs_shape.rank()) { return InvalidArgument( "Convolution arguments must have same number of " "dimensions. Got: %s and %s", ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape)); } - int num_dims = ShapeUtil::Rank(lhs_shape); + int num_dims = lhs_shape.rank(); if (num_dims < 2) { return InvalidArgument( "Convolution expects argument arrays with >= 3 dimensions. " @@ -1582,7 +1581,7 @@ XlaOp XlaBuilder::Sort(const XlaOp& keys, absl::Span values, *instr.mutable_shape() = shape.ToProto(); if (dimension == -1) { TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); - dimension = ShapeUtil::Rank(keys_shape) - 1; + dimension = keys_shape.rank() - 1; } instr.add_dimensions(dimension); std::vector operands{keys}; @@ -1652,12 +1651,12 @@ XlaOp XlaBuilder::Map(absl::Span operands, *instr.mutable_shape() = shape.ToProto(); Shape output_shape(instr.shape()); - const int64 output_rank = ShapeUtil::Rank(output_shape); + const int64 output_rank = output_shape.rank(); AddCalledComputation(computation, &instr); std::vector new_operands(operands.begin(), operands.end()); for (XlaOp& new_operand : new_operands) { TF_ASSIGN_OR_RETURN(Shape shape, GetShape(new_operand)); - const int64 rank = ShapeUtil::Rank(shape); + const int64 rank = shape.rank(); if (rank != output_rank) { TF_ASSIGN_OR_RETURN(new_operand, InDimBroadcast(output_shape, new_operand, {})); @@ -1866,7 +1865,7 @@ XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - std::vector all_dimnos(ShapeUtil::Rank(operand_shape)); + std::vector all_dimnos(operand_shape.rank()); std::iota(all_dimnos.begin(), all_dimnos.end(), 0); return Reduce(operand, init_value, computation, all_dimnos); }); diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index 2a0241af3ef..7e22a32e545 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -141,7 +141,7 @@ namespace xla { /* static */ bool IndexUtil::IndexInBounds(const Shape& shape, absl::Span index) { - int64 rank = ShapeUtil::Rank(shape); + int64 rank = shape.rank(); if (rank != index.size()) { return false; } diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index ddccd8c798d..0f07f123a87 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -211,19 +211,19 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } if (layout.format() == DENSE) { - if (layout.minor_to_major_size() != ShapeUtil::Rank(shape)) { + if (layout.minor_to_major_size() != shape.rank()) { return InvalidArgument( "layout minor_to_major field contains %d elements, " "but shape is rank %d: {%s}; shape: %s", - layout.minor_to_major_size(), ShapeUtil::Rank(shape), + layout.minor_to_major_size(), shape.rank(), absl::StrJoin(layout.minor_to_major(), ", "), shape.ShortDebugString()); } - std::vector dimensions_in_layout(ShapeUtil::Rank(shape), false); - for (int64 i = 0; i < ShapeUtil::Rank(shape); ++i) { + std::vector dimensions_in_layout(shape.rank(), false); + for (int64 i = 0; i < shape.rank(); ++i) { int64 dim = layout.minor_to_major(i); - if (dim < 0 || dim >= ShapeUtil::Rank(shape)) { + if (dim < 0 || dim >= shape.rank()) { return InvalidArgument( "layout minor_to_major field has out-of-bounds value: %s", HumanString(layout)); @@ -376,7 +376,7 @@ Status CopyLayoutInternal(const Shape& src, Shape* dst) { } } else { if (src.has_layout()) { - if (ShapeUtil::Rank(src) != ShapeUtil::Rank(*dst)) { + if (src.rank() != dst->rank()) { return InvalidArgument("cannot copy layout from shape: ranks differs"); } TF_RETURN_IF_ERROR( @@ -410,7 +410,7 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { } return true; } else if (ShapeUtil::IsArray(lhs)) { - return ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) && + return lhs.rank() == rhs.rank() && LayoutUtil::Equal(lhs.layout(), rhs.layout()); } else { // Layouts of non-array and non-tuple shapes is ignored. diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 277c98721e5..3b596078b5c 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -129,7 +129,7 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { new char[max_sparse_elements * ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]); piece->set_sparse_indices( - new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape))); + new SparseIndexArray(max_sparse_elements, shape.rank())); } else { piece->set_buffer(new char[piece->size_bytes()]); } @@ -208,16 +208,15 @@ template Status MutableLiteralBase::CopySliceFromInternal( const LiteralBase& src_literal, absl::Span src_base, absl::Span dest_base, absl::Span copy_size) { - TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); - TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size()); + TF_RET_CHECK(src_literal.shape().rank() == src_base.size()); + TF_RET_CHECK(shape().rank() == dest_base.size()); auto linear_index = [](const Shape& shape, absl::Span multi_index) { return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index); }; - if (ShapeUtil::Rank(src_literal.shape()) == 0 || - ShapeUtil::Rank(shape()) == 0) { + if (src_literal.shape().rank() == 0 || shape().rank() == 0) { // If any of the two shapes are scalars, we can just call the StridedCopy() // directly, and we know we will be copying only one value. TF_RET_CHECK(copy_size.empty()); @@ -375,7 +374,7 @@ void CopyElementsBetween(absl::Span dest, if (ShapeUtil::IsZeroElementArray(dest_shape)) { return; } - std::vector index(ShapeUtil::Rank(dest_shape)); + std::vector index(dest_shape.rank()); do { dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] = src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)]; @@ -392,7 +391,7 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { memcpy(buffer(), src.buffer(), src.size_bytes()); } else { TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape())); - std::vector origin(ShapeUtil::Rank(subshape()), 0); + std::vector origin(subshape().rank(), 0); switch (subshape().element_type()) { #define COPY_ELEMENTS(XLA_T, NATIVE_T) \ case (XLA_T): \ @@ -563,7 +562,7 @@ Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal, void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) { CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK_EQ(shape().rank(), 1); CHECK_EQ(element_count(), values.bits()); CHECK_EQ(shape().element_type(), PRED); for (int64 i = 0; i < static_cast(values.bits()); ++i) { @@ -648,8 +647,7 @@ StatusOr LiteralBase::Reshape( } Literal output; if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { - output = - Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape()))); + output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank())); } else { output = Clone(); } @@ -672,7 +670,7 @@ StatusOr LiteralBase::Reshape( Literal LiteralBase::Transpose(absl::Span permutation) const { CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; - CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) + CHECK(IsPermutation(permutation, shape().rank())) << "Given permutation is not a permutation of dimension numbers"; // To transpose the array, we just permute the dimensions and layout, and // do a straight memory copy of the raw data set. @@ -711,10 +709,10 @@ template Literal LiteralBase::SliceInternal( const Shape& result_shape, absl::Span start_indices) const { Literal result_literal(result_shape); - DimensionVector new_indices(ShapeUtil::Rank(result_shape)); + DimensionVector new_indices(result_shape.rank()); result_literal.EachCell( [&](absl::Span indices, NativeT /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { + for (int64 i = 0; i < result_shape.rank(); ++i) { new_indices[i] = indices[i] + start_indices[i]; } NativeT value = Get(new_indices); @@ -728,7 +726,7 @@ Literal LiteralBase::Slice(absl::Span start_indices, CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; DimensionVector result_dimensions; - for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) { + for (int64 dnum = 0; dnum < shape().rank(); ++dnum) { CHECK_GE(start_indices[dnum], 0); CHECK_LE(limit_indices[dnum], shape().dimensions(dnum)) << "dnum = " << dnum; @@ -1056,7 +1054,7 @@ void SparseArrayToStringHelper(const LiteralBase& literal, pieces->push_back(ShapeToString(print_layout, subshape)); } pieces->push_back("{"); - int64 rank = ShapeUtil::Rank(subshape); + int64 rank = subshape.rank(); int64 num_elements = literal.sparse_element_count(); for (int64 i = 0; i < num_elements; ++i) { if (i > 0) { @@ -1079,7 +1077,7 @@ void DenseArrayToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, bool print_shape, bool print_layout, std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); - int64 rank = ShapeUtil::Rank(subshape); + int64 rank = subshape.rank(); std::function dimensions, std::vector*)> to_string_recursive = [&](absl::Span dimensions, @@ -1433,7 +1431,7 @@ StatusOr LiteralBase::ConvertToShape(const Shape& dest_shape) const { template bool LiteralBase::Piece::EqualElementsInternal( const LiteralBase::Piece& other, std::vector* multi_index) const { - if (multi_index->size() == ShapeUtil::Rank(subshape())) { + if (multi_index->size() == subshape().rank()) { return (Get(*multi_index) == other.Get(*multi_index)); } for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) { @@ -1722,7 +1720,7 @@ bool LiteralBase::IsR1Iota() const { return false; } - if (ShapeUtil::Rank(shape()) != 1) { + if (shape().rank() != 1) { return false; } @@ -1932,14 +1930,12 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { if (LayoutUtil::IsSparseArray(subshape())) { // Compute the number of elements (indices) in the sparse shape and reserve // the necessary space in spare_indices. - TF_RET_CHECK(ShapeUtil::Rank(subshape()) != 0) - << "Scalar shapes cannot be sparse"; - TF_RET_CHECK(proto.sparse_indices_size() % ShapeUtil::Rank(subshape()) == 0) + TF_RET_CHECK(subshape().rank() != 0) << "Scalar shapes cannot be sparse"; + TF_RET_CHECK(proto.sparse_indices_size() % subshape().rank() == 0) << "Unexpected number of indices in proto (" << proto.sparse_indices_size() << ") for shape of rank " - << ShapeUtil::Rank(subshape()); - const int64 index_count = - proto.sparse_indices_size() / ShapeUtil::Rank(subshape()); + << subshape().rank(); + const int64 index_count = proto.sparse_indices_size() / subshape().rank(); sparse_indices()->Resize(index_count); // Copy the indices from the proto into the SparseIndexArray object. @@ -2065,7 +2061,7 @@ int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const { string LiteralBase::GetR1U8AsString() const { CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK_EQ(shape().rank(), 1); CHECK_EQ(shape().element_type(), U8); return string(absl::bit_cast(data().data()), ShapeUtil::ElementsIn(shape())); diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index 67e908e7ec4..67db56c2ef2 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -961,7 +961,7 @@ void MutableLiteralBase::AppendSparseElement( Piece& p = piece(shape_index); const Shape& subshape = p.subshape(); CHECK(LayoutUtil::IsSparseArray(subshape)); - int64 rank = ShapeUtil::Rank(subshape); + int64 rank = subshape.rank(); CHECK_EQ(multi_index.size(), rank); int64 last_element = p.sparse_indices()->index_count(); CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout())); @@ -977,7 +977,7 @@ void LiteralBase::EachCell( if (ShapeUtil::IsZeroElementArray(shape())) { return; } - std::vector indices(ShapeUtil::Rank(shape()), 0); + std::vector indices(shape().rank(), 0); do { per_cell(indices, Get(indices)); } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices))); @@ -986,7 +986,7 @@ void LiteralBase::EachCell( template inline void MutableLiteralBase::PopulateR1(absl::Span values) { CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK_EQ(shape().rank(), 1); CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType()); @@ -998,7 +998,7 @@ template void MutableLiteralBase::PopulateR2( std::initializer_list> values) { CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 2); + CHECK_EQ(shape().rank(), 2); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType()); @@ -1024,7 +1024,7 @@ void MutableLiteralBase::PopulateFromArray(const Array& values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType()); - CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions()); + CHECK_EQ(shape().rank(), values.num_dimensions()); for (int dim = 0; dim < values.num_dimensions(); ++dim) { CHECK_EQ(values.dim(dim), shape().dimensions(dim)); } @@ -1053,7 +1053,7 @@ void MutableLiteralBase::PopulateSparse(SparseIndexArray indices, absl::Span values, bool sort) { CHECK(LayoutUtil::IsSparseArray(shape())); - int rank = ShapeUtil::Rank(shape()); + int rank = shape().rank(); CHECK_EQ(indices.rank(), rank); int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout()); CHECK_LE(indices.max_indices(), max_elements); @@ -1077,7 +1077,7 @@ template Status MutableLiteralBase::PopulateInternal(const FnType& generator, bool parallel) { const Shape& this_shape = shape(); - const int64 rank = ShapeUtil::Rank(this_shape); + const int64 rank = this_shape.rank(); TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); TF_RET_CHECK(this_shape.element_type() == primitive_util::NativeToPrimitiveType()); diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 1ac9a48e805..91316b77f64 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -463,7 +463,7 @@ class NearComparator { } return; } - std::vector multi_index(ShapeUtil::Rank(actual_.shape()), 0); + std::vector multi_index(actual_.shape().rank(), 0); CompareLiteralsSlow(0, &multi_index); } @@ -777,7 +777,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { } } } else if (ShapeUtil::IsArray(expected)) { - if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { + if (expected.rank() != actual.rank()) { return InvalidArgument("want rank of %s got rank of %s", ShapeUtil::HumanString(expected), ShapeUtil::HumanString(actual)); diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index b0aa024c747..74d2d25b4bc 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -132,7 +132,7 @@ PyObject* PyShapeInfoFromXlaShape(const Shape& shape) { PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i))); } } else { - int rank = ShapeUtil::Rank(shape); + int rank = shape.rank(); dimensions = PyTuple_New(rank); for (int i = 0; i < rank; ++i) { PyTuple_SET_ITEM(dimensions, i, @@ -354,7 +354,7 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { } return tuple; } else { - int rank = ShapeUtil::Rank(literal.shape()); + int rank = literal.shape().rank(); std::vector dimensions(rank); // NOLINT - PyArray requires a long* for (int i = 0; i < rank; i++) { dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i); diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 92f28a9f8aa..3ba67f69e8c 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -552,7 +552,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( Literal result_literal = evaluator.Evaluate(*computation, {}).ConsumeValueOrDie(); - CHECK_EQ(ShapeUtil::Rank(result_literal.shape()), 4); + CHECK_EQ(result_literal.shape().rank(), 4); auto result = absl::make_unique>(result_literal.shape().dimensions(0), result_literal.shape().dimensions(1), diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 9e453203ce1..8157be1e0a6 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -251,7 +251,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Reshapes an instruction to rank 1 if it is not already rank 1. HloInstruction* Flatten(HloInstruction* hlo) { - if (ShapeUtil::Rank(hlo->shape()) == 1) { + if (hlo->shape().rank() == 1) { return hlo; } return computation_->AddInstruction(HloInstruction::CreateReshape( @@ -687,7 +687,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( return Status::OK(); } PaddingConfig padding_config; - for (int64 dim = 0; dim < ShapeUtil::Rank(operands[0]->shape()); ++dim) { + for (int64 dim = 0; dim < operands[0]->shape().rank(); ++dim) { auto padding_config_dim = padding_config.add_dimensions(); padding_config_dim->set_edge_padding_high(0); padding_config_dim->set_edge_padding_low(0); @@ -754,7 +754,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { } // If a literal is an increasing sequence from zero, replace it with an iota. - if (ShapeUtil::Rank(constant->shape()) == 1 && + if (constant->shape().rank() == 1 && ShapeUtil::ElementsIn(constant->shape()) > 1 && constant->literal().IsR1Iota()) { return ReplaceWithNewInstruction( @@ -930,9 +930,9 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( return -1; }; - const int64 dot_rank = ShapeUtil::Rank(dot->shape()); - const int64 rhs_rank = ShapeUtil::Rank(rhs->shape()); - const int64 lhs_rank = ShapeUtil::Rank(lhs->shape()); + const int64 dot_rank = dot->shape().rank(); + const int64 rhs_rank = rhs->shape().rank(); + const int64 lhs_rank = lhs->shape().rank(); const auto& dnums = dot->dot_dimension_numbers(); if (dnums.rhs_contracting_dimensions_size() > 1) { return false; @@ -1036,7 +1036,7 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // ) if (lhs_rank == 1 || (lhs_rank == 2 && lhs->shape().dimensions(lhs_kept_dim) == 1)) { - if (ShapeUtil::Rank(rhs->shape()) == 1) { + if (rhs->shape().rank() == 1) { TF_RETURN_IF_ERROR( ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( multiply(Flatten(lhs), rhs), 0)))); @@ -1449,8 +1449,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { dot->shape().element_type() != BF16) { return Status::OK(); } - if (ShapeUtil::Rank(lhs->shape()) > 2 || ShapeUtil::Rank(rhs->shape()) > 2 || - ShapeUtil::Rank(dot->shape()) > 2) { + if (lhs->shape().rank() > 2 || rhs->shape().rank() > 2 || + dot->shape().rank() > 2) { if (options_.enable_dot_strength_reduction() && !options_.is_layout_sensitive()) { TF_RETURN_IF_ERROR(HandleDotStrengthReduction(dot).status()); @@ -1732,8 +1732,7 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { // A degenerate broadcast that has the same input and output rank can be // converted into a transpose. - if (ShapeUtil::Rank(broadcast->shape()) == - ShapeUtil::Rank(operand->shape()) && + if (broadcast->shape().rank() == operand->shape().rank() && ShapeUtil::ElementsIn(broadcast->shape()) == ShapeUtil::ElementsIn(operand->shape())) { VLOG(10) << "transform broadcast(X) -> transpose(X) where " @@ -1888,7 +1887,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { if (HasInteriorPadding(pad->padding_config())) { PaddingConfig padding_config = pad->padding_config(); bool cleared_interior_padding = false; - for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { + for (int64 i = 0; i < pad->shape().rank(); ++i) { if (padding_config.dimensions(i).interior_padding() > 0 && pad->operand(0)->shape().dimensions(i) == 1) { cleared_interior_padding = true; @@ -2276,7 +2275,7 @@ StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) { VLOG(10) << "Trying to simplify scalar slice of concat"; // Only do this for R1, there's no chance of this being useful otherwise. - if (ShapeUtil::Rank(slice->shape()) != 1) { + if (slice->shape().rank() != 1) { VLOG(10) << "Not folding, slice is not rank 1"; return false; } @@ -2326,7 +2325,7 @@ StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( return false; } HloInstruction* new_slice_operand = reshape->mutable_operand(0); - int64 slice_rank = ShapeUtil::Rank(slice->shape()); + int64 slice_rank = slice->shape().rank(); std::vector sliced_dims; for (int64 i = 0; i < slice_rank; ++i) { if (slice->slice_starts(i) != 0 || @@ -2338,7 +2337,7 @@ StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( if (sliced_dims.size() == 1 && sliced_dims[0] == 0 && slice->slice_starts(0) == 0) { const Shape& new_slice_shape = new_slice_operand->shape(); - const int64 rank = ShapeUtil::Rank(new_slice_shape); + const int64 rank = new_slice_shape.rank(); std::vector new_slice_starts(rank, 0); std::vector new_slice_stides(rank, 1); std::vector new_slice_limits(new_slice_shape.dimensions().begin(), @@ -2456,8 +2455,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { // A Transpose feeding a reduce can simply permute the reduction dimensions // field if the output of the reduce is a vector or scalar. Higher ranked // result may require a transpose of the output. - if (ShapeUtil::Rank(reduce->shape()) <= 1 && - arg->opcode() == HloOpcode::kTranspose) { + if (reduce->shape().rank() <= 1 && arg->opcode() == HloOpcode::kTranspose) { auto transpose_dimensions = arg->dimensions(); std::vector new_reduce_dimensions; for (auto dim : dimensions) { @@ -2516,8 +2514,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { std::vector> unmodified_dims = ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(), arg->shape()); - std::vector arg_dim_in_output(ShapeUtil::Rank(arg->shape()), true); - std::vector arg_dim_unmodified(ShapeUtil::Rank(arg->shape()), false); + std::vector arg_dim_in_output(arg->shape().rank(), true); + std::vector arg_dim_unmodified(arg->shape().rank(), false); for (auto dim : dimensions) { arg_dim_in_output[dim] = false; } @@ -2542,7 +2540,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { } } std::vector new_reduce_dimensions; - for (int64 i = 0; i < ShapeUtil::Rank(arg->operand(0)->shape()); ++i) { + for (int64 i = 0; i < arg->operand(0)->shape().rank(); ++i) { if (dimensions_not_to_reduce.count(i) == 0) { new_reduce_dimensions.push_back(i); } @@ -2779,7 +2777,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( // Carry out the folding of the pad into reduce_window. VLOG(10) << "Folding pad into reduce-window."; Window new_window = window; - const int64 rank = ShapeUtil::Rank(reduce_window->shape()); + const int64 rank = reduce_window->shape().rank(); TF_RET_CHECK(pad_config.dimensions_size() == rank); TF_RET_CHECK(window.dimensions_size() == rank); for (int64 i = 0; i < rank; ++i) { @@ -2862,7 +2860,7 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { // - Use this as the indices parameter of scatter, and set updates // of the scatter to be a reshaped 'values' parameter of sort (adding // 'rank' many 1 dimensions at the end). - int64 rank = ShapeUtil::Rank(operand->shape()); + int64 rank = operand->shape().rank(); Shape extended_shape = operand->shape(); extended_shape.add_dimensions(1); extended_shape.mutable_layout()->add_minor_to_major(rank); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index a9d617cbf6d..51ad748ff82 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -3498,7 +3498,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // Create the reduce-window. Window window; - for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { + for (int64 i = 0; i < pad->shape().rank(); ++i) { auto* dim = window.add_dimensions(); dim->set_size(1); dim->set_padding_low(10); @@ -3584,7 +3584,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { // Create the reduce-window. Window window; - for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { + for (int64 i = 0; i < pad->shape().rank(); ++i) { auto* dim = window.add_dimensions(); dim->set_size(1); dim->set_padding_low(10); diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 0e6ca1871b3..e5f5c3edb2a 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -123,7 +123,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { auto elements_per_feature_u32 = add_instruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); - for (int64 i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { + for (int64 i = 0; i < operand->shape().rank(); ++i) { if (i == feature_index) { continue; } @@ -229,7 +229,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {})); std::vector dimensions_without_feature; - for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { + for (int64 i = 0; i < operand_shape.rank(); ++i) { if (i != feature_index) { dimensions_without_feature.push_back(i); } @@ -357,7 +357,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( std::vector dimensions_without_feature; - for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { + for (int64 i = 0; i < operand_shape.rank(); ++i) { if (i != feature_index) { dimensions_without_feature.push_back(i); } @@ -494,7 +494,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( std::vector dimensions_without_feature; - for (int64 i = 0; i < ShapeUtil::Rank(activation_shape); ++i) { + for (int64 i = 0; i < activation_shape.rank(); ++i) { if (i != feature_index) { dimensions_without_feature.push_back(i); } diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 37cefcb2e82..1525a33af7a 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -1550,7 +1550,7 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { } // Return whether the given shape is rank 2. -static bool IsRank2(const Shape& shape) { return ShapeUtil::Rank(shape) == 2; } +static bool IsRank2(const Shape& shape) { return shape.rank() == 2; } // In a gemm operation where output = lhs * rhs, check whether the given shapes // are valid for the operation. diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index ed7fe59c80e..f29f8f8fec8 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -535,7 +535,7 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { higher_dimensions *= normalized_keys_shape.dimensions(i); } int64 lower_dimensions = 1; - for (int64 i = ShapeUtil::Rank(normalized_keys_shape) - 1; + for (int64 i = normalized_keys_shape.rank() - 1; i > physical_dimension_to_sort; --i) { lower_dimensions *= normalized_keys_shape.dimensions(i); } @@ -779,8 +779,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { const auto init_value = select_and_scatter->operand(2); const Window& window = select_and_scatter->window(); PrimitiveType operand_element_type = operand->shape().element_type(); - const int64 rank = ShapeUtil::Rank(operand->shape()); - CHECK_EQ(rank, ShapeUtil::Rank(source->shape())); + const int64 rank = operand->shape().rank(); + CHECK_EQ(rank, source->shape().rank()); CHECK_EQ(rank, window.dimensions_size()); // TODO(b/31410564): Implement dilation for select-and-scatter. diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index 6d0472689bf..f2b7e9a1861 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -173,7 +173,7 @@ Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) { // Find out the new dynamic dimension after reduce. int64 dimensions_not_reduced_count = 0; - for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { + for (int i = 0; i < operand->shape().rank(); ++i) { if (dimension == i) { parent_->SetDynamicSize(reduce, {}, dimensions_not_reduced_count, dynamic_size); @@ -207,7 +207,7 @@ Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) { result_dim_mapping[i] = current_result_dims++; } - for (int64 i = 0; i < ShapeUtil::Rank(dot->operand(0)->shape()); i++) { + for (int64 i = 0; i < dot->operand(0)->shape().rank(); i++) { if (!absl::c_linear_search( dimension_numbers.lhs_contracting_dimensions(), i)) { if (operand_index == 0) { @@ -217,7 +217,7 @@ Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) { } } - for (int64 i = 0; i < ShapeUtil::Rank(dot->operand(1)->shape()); i++) { + for (int64 i = 0; i < dot->operand(1)->shape().rank(); i++) { if (!absl::c_linear_search( dimension_numbers.rhs_contracting_dimensions(), i) && !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc b/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc index c8bfc890506..3d0eddb0d85 100644 --- a/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc +++ b/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc @@ -121,10 +121,11 @@ Status DynamicParameterBinding::Verify(const HloModule& module) const { dynamic_dimension.parameter_index)); TF_RET_CHECK( dynamic_dimension.dimension < - ShapeUtil::Rank(ShapeUtil::GetSubshape( + ShapeUtil::GetSubshape( entry->parameter_instruction(dynamic_dimension.parameter_num) ->shape(), - dynamic_dimension.parameter_index))); + dynamic_dimension.parameter_index) + .rank()); return Status::OK(); }); } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 6f1f95f2e90..6f928fcbaab 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -1327,9 +1327,9 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( // If implicit broadcast is needed, the source dimensions that are broadcast // have index 0. - CHECK_EQ(ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(hlo.shape())); + CHECK_EQ(operand_shape.rank(), hlo.shape().rank()); llvm_ir::IrArray::Index source_index(target_index.GetType()); - for (int64 i = 0; i < ShapeUtil::Rank(hlo.shape()); ++i) { + for (int64 i = 0; i < hlo.shape().rank(); ++i) { if (hlo.shape().dimensions(i) == operand_shape.dimensions(i)) { source_index.push_back(target_index[i]); } else { @@ -1750,7 +1750,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( const llvm_ir::IrArray::Index& index) { // Emit IR to read dynamic start indices from hlo->operand(1). const HloInstruction* input_hlo = hlo->operand(0); - const int64 rank = ShapeUtil::Rank(input_hlo->shape()); + const int64 rank = input_hlo->shape().rank(); // Use the same index type for all tensor accesses in the same kernel. llvm::Type* index_type = index.GetType(); llvm_ir::IrArray::Index slice_start_index(index_type, rank); @@ -1893,7 +1893,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( const HloInstruction* update_hlo = hlo->operand(1); const HloInstruction* start_hlo = hlo->operand(2); // Calculate slice start/end indices. - const int64 rank = ShapeUtil::Rank(input_hlo->shape()); + const int64 rank = input_hlo->shape().rank(); llvm_ir::IrArray::Index slice_start_index(index.GetType(), rank); llvm_ir::IrArray::Index slice_limit_index(index.GetType(), rank); // Slice intersection gathers (ANDs) conditions on all ranks for which @@ -2225,7 +2225,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( auto* iota = Cast(hlo); PrimitiveType element_type = iota->shape().element_type(); IrArray::Index elem_index = - ShapeUtil::Rank(iota->shape()) > 1 + iota->shape().rank() > 1 ? target_index.SourceIndexOfBroadcast( iota->shape(), ShapeUtil::MakeShapeWithDescendingLayout( diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc index 5aa4f839f4b..958e0b9c6e7 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc @@ -50,10 +50,10 @@ static HloInstruction* PadInstruction(HloInstruction* instr, auto* zero = comp->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); - PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape)); + PaddingConfig pad_config = MakeNoPaddingConfig(shape.rank()); bool added_padding = false; - for (int64 dim = 0; dim < ShapeUtil::Rank(shape); ++dim) { + for (int64 dim = 0; dim < shape.rank(); ++dim) { if (shape.dimensions(dim) == new_shape.dimensions(dim)) { continue; } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc index 3a09d4d4716..17d0f7aa7bf 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc @@ -219,7 +219,7 @@ bool CudnnConvPaddingLegalization::CanonicalizeBackwardFilterConvolution( Window new_backward_conv_window = backward_conv->window(); // input_padding_config is the config of the kPad to be inserted. PaddingConfig input_padding_config = - MakeNoPaddingConfig(ShapeUtil::Rank(input->shape())); + MakeNoPaddingConfig(input->shape().rank()); ConvolutionDimensionNumbers backward_conv_dnums = backward_conv->convolution_dimension_numbers(); for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 27f07b1d581..b8fbe7d2bcb 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -315,8 +315,7 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction()); CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), dim_nums.rhs_batch_dimensions_size()); - CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, - ShapeUtil::Rank(output_shape_)); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, output_shape_.rank()); int64 row_dim = dim_nums.lhs_batch_dimensions_size(); int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index 452e763a8ea..542af51e624 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -43,14 +43,14 @@ bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer, const Layout* max_rank_layout; for (HloInstruction* param : params) { if (ShapeUtil::IsArray(param->shape()) && - ShapeUtil::Rank(param->shape()) > max_rank) { - max_rank = ShapeUtil::Rank(param->shape()); + param->shape().rank() > max_rank) { + max_rank = param->shape().rank(); max_rank_layout = ¶m->shape().layout(); } } return absl::c_all_of(params, [&](HloInstruction* param) { return (!ShapeUtil::IsArray(param->shape())) || - (ShapeUtil::Rank(param->shape()) < max_rank) || + (param->shape().rank() < max_rank) || (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout)); }); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index f59da2caa18..ccd83e262e1 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -196,9 +196,9 @@ Status GpuLayoutAssignment::AddBackendConstraints( CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), dim_nums.rhs_batch_dimensions_size()); CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, - ShapeUtil::Rank(instruction->shape())); + instruction->shape().rank()); for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) { - CHECK_LT(batch_dim, ShapeUtil::Rank(instruction->shape()) - 2); + CHECK_LT(batch_dim, instruction->shape().rank() - 2); } // Set both inputs and the output to default layout. @@ -215,11 +215,11 @@ Status GpuLayoutAssignment::AddBackendConstraints( TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(output_shape, instruction)); } else if (instruction->opcode() == HloOpcode::kSort && - ShapeUtil::Rank(instruction->operand(0)->shape()) > 1) { + instruction->operand(0)->shape().rank() > 1) { // Make sure that all the operands and the output(s) have the same layout. Shape keys_shape = instruction->operand(0)->shape(); Layout keys_layout = - LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(keys_shape)); + LayoutUtil::GetDefaultLayoutForRank(keys_shape.rank()); for (int64 i = 0; i < instruction->operand_count(); ++i) { Shape shape = instruction->operand(i)->shape(); *shape.mutable_layout() = keys_layout; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 33e41a2782b..5d25a032a99 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -40,7 +40,7 @@ namespace { // Return whether the given shape is rank 2 excluding the batch dimensions. bool IsRank2(const Shape& shape, int64 batch_dimensions_size) { - return ShapeUtil::Rank(shape) == batch_dimensions_size + 2; + return shape.rank() == batch_dimensions_size + 2; } // In a gemm operation where output = lhs * rhs, check whether the given shapes diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 1472853dc44..48d9840620f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -698,8 +698,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( const auto* source = select_and_scatter->operand(1); const Window& window = select_and_scatter->window(); PrimitiveType operand_element_type = operand->shape().element_type(); - const int64 rank = ShapeUtil::Rank(operand->shape()); - CHECK_EQ(rank, ShapeUtil::Rank(source->shape())); + const int64 rank = operand->shape().rank(); + CHECK_EQ(rank, source->shape().rank()); CHECK_EQ(rank, window.dimensions_size()); TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, @@ -1015,7 +1015,7 @@ Status IrEmitterUnnested::EmitScatter( int64 raw_window_multidim_idx = 0; std::vector input_window_multidim; std::vector input_window_bounds; - for (int64 i = 0, e = ShapeUtil::Rank(operand->shape()); i != e; ++i) { + for (int64 i = 0, e = operand->shape().rank(); i != e; ++i) { if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) { input_window_bounds.push_back(1); // Trivial dimension. input_window_multidim.push_back(index.GetConstantWithIndexType(0)); @@ -1027,12 +1027,11 @@ Status IrEmitterUnnested::EmitScatter( ++raw_window_multidim_idx; } } - DCHECK_EQ(input_window_multidim.size(), ShapeUtil::Rank(operand->shape())); + DCHECK_EQ(input_window_multidim.size(), operand->shape().rank()); // Insert a 1 dimension at the end if index_vector_dim requests one. Shape scatter_indices_shape = scatter_indices->shape(); - if (dim_numbers.index_vector_dim() == - ShapeUtil::Rank(scatter_indices_shape)) { + if (dim_numbers.index_vector_dim() == scatter_indices_shape.rank()) { scatter_indices_shape.add_dimensions(1); scatter_indices_shape.mutable_layout()->add_minor_to_major( dim_numbers.index_vector_dim()); @@ -3191,7 +3190,7 @@ Status AreFusedReductionOutputsConsistent( // dimensions from minor to major. DimensionVector GetDimensionsToKeepMinorToMajor( const Shape& input_shape, absl::Span dims_to_reduce) { - DimensionVector input_dims(ShapeUtil::Rank(input_shape), 0); + DimensionVector input_dims(input_shape.rank(), 0); absl::c_iota(input_dims, 0); DimensionVector input_dims_to_keep; for (int input_dim : input_dims) { @@ -3231,7 +3230,7 @@ std::tuple GetReductionToVectorDimensions( if (input_dims_to_keep_minor_to_major.empty()) { return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor); } - DimensionVector input_dims(ShapeUtil::Rank(input_shape), 0); + DimensionVector input_dims(input_shape.rank(), 0); absl::c_iota(input_dims, 0); absl::Span minor_to_major = LayoutUtil::MinorToMajor(input_shape); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index e41aeab19e4..1678fba1728 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -189,8 +189,7 @@ StatusOr MakeMapHlo(absl::Span operands, for (const HloInstruction* operand : operands) { CHECK_EQ(computation, operand->parent()); operand_shapes.push_back(&operand->shape()); - max_operand_rank = - std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); + max_operand_rank = std::max(max_operand_rank, operand->shape().rank()); } std::vector map_dims(max_operand_rank); std::iota(map_dims.begin(), map_dims.end(), 0); @@ -207,7 +206,7 @@ StatusOr MakeReduceHlo(HloInstruction* operand, HloOpcode binary_opcode, HloModule* module) { DCHECK_NE(nullptr, module); - std::vector all_dims(ShapeUtil::Rank(operand->shape())); + std::vector all_dims(operand->shape().rank()); std::iota(all_dims.begin(), all_dims.end(), 0); auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {}); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 934c082bb9f..3d3c6af2f36 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -443,7 +443,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { // concatenate dimensions of the operands taking part of the operation. const Shape& reference_shape = operands[0]->shape(); CHECK(ShapeUtil::IsArray(reference_shape)); - const int64 rank = ShapeUtil::Rank(reference_shape); + const int64 rank = reference_shape.rank(); const int64 concat_dim = concatenate->dimensions()[0]; CHECK_GE(concat_dim, 0); CHECK_LT(concat_dim, rank); @@ -1036,11 +1036,9 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0)); - TF_RET_CHECK(broadcast->dimensions().size() == - ShapeUtil::Rank(operand.shape())) + TF_RET_CHECK(broadcast->dimensions().size() == operand.shape().rank()) << "broadcast dimensions is of size: " << broadcast->dimensions().size() - << " and rank of operand_to_broadcast is: " - << ShapeUtil::Rank(operand.shape()); + << " and rank of operand_to_broadcast is: " << operand.shape().rank(); // Checks that operand's dimensions are the same as the broadcast's // dimensions along the dimensions to be broadcasted. for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { @@ -1251,7 +1249,7 @@ template StatusOr EvaluateSortInternal(HloInstruction* sort, const Literal& keys_literal, const Literal& values_literal) { - auto rank = ShapeUtil::Rank(keys_literal.shape()); + auto rank = keys_literal.shape().rank(); TF_RET_CHECK( ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape())) << "Sort keys and values must have the same dimensions"; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 3ace2f54432..7e0dadaf3e6 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1005,8 +1005,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { CHECK_GE(num_spatial_dims, 0); CHECK_EQ(window.dimensions_size(), num_spatial_dims); - const auto lhs_rank = ShapeUtil::Rank(lhs_shape); - const auto rhs_rank = ShapeUtil::Rank(rhs_shape); + const auto lhs_rank = lhs_shape.rank(); + const auto rhs_rank = rhs_shape.rank(); CHECK_EQ(num_spatial_dims + 2, lhs_rank); CHECK_EQ(num_spatial_dims + 2, rhs_rank); @@ -1175,8 +1175,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const auto& dnums = dot->dot_dimension_numbers(); - const int64 lhs_rank = ShapeUtil::Rank(lhs->shape()); - const int64 rhs_rank = ShapeUtil::Rank(rhs->shape()); + const int64 lhs_rank = lhs->shape().rank(); + const int64 rhs_rank = rhs->shape().rank(); CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); @@ -1238,8 +1238,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const auto& dnums = dot->dot_dimension_numbers(); - const auto lhs_rank = ShapeUtil::Rank(lhs->shape()); - const auto rhs_rank = ShapeUtil::Rank(rhs->shape()); + const auto lhs_rank = lhs->shape().rank(); + const auto rhs_rank = rhs->shape().rank(); CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); @@ -1329,7 +1329,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { CHECK(ShapeUtil::IsArray(pad->operand(0)->shape())); // Padding value must be scalar. CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape())); - CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()), + CHECK_EQ(pad->operand(0)->shape().rank(), pad->padding_config().dimensions_size()); TF_ASSIGN_OR_RETURN(auto inferred_return_shape, @@ -1352,9 +1352,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& evaluated_operand = parent_->GetEvaluatedLiteralFor(pad->operand(0)); - std::vector input_index(ShapeUtil::Rank(evaluated_operand.shape()), - 0); - std::vector target_index(ShapeUtil::Rank(result.shape()), 0); + std::vector input_index(evaluated_operand.shape().rank(), 0); + std::vector target_index(result.shape().rank(), 0); // Loop through each element of the operand, assign them to the // corresponding index of the resulting padded literal. @@ -1609,7 +1608,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys); int64 sort_dim = sort->dimensions(0); int64 sort_dim_elements = keys->shape().dimensions(sort_dim); - int64 rank = ShapeUtil::Rank(keys->shape()); + int64 rank = keys->shape().rank(); if (rank == 0) { // Nothing to sort. parent_->evaluated_[sort] = keys_literal.Clone(); @@ -1868,7 +1867,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source); - int64 rank = ShapeUtil::Rank(operand_literal.shape()); + int64 rank = operand_literal.shape().rank(); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); DimensionVector source_index(rank, 0); @@ -1980,7 +1979,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { operand->shape().element_type(), window_dimension_sizes); DimensionVector window_index(window.dimensions_size()); - DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); + DimensionVector operand_index(operand_literal.shape().rank()); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); Literal result(reduce_window->shape()); @@ -2411,7 +2410,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); - const int64 rank = ShapeUtil::Rank(operand->shape()); + const int64 rank = operand->shape().rank(); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); auto func = [&](absl::Span out_index) { DimensionVector operand_index(rank); @@ -2648,12 +2647,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } auto result = LiteralUtil::CreateR1(data); - if (ShapeUtil::Rank(iota->shape()) > 1) { + if (iota->shape().rank() > 1) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[iota], result.Broadcast(iota->shape(), {iota->iota_dimension()})); } else { - TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1); + TF_RET_CHECK(iota->shape().rank() == 1); parent_->evaluated_[iota] = std::move(result); } @@ -2683,7 +2682,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // // This lets you calculate LI given the multidimensional indices in any order. static DimensionVector MakeDimMultipliers(const Shape& shape) { - DimensionVector v(ShapeUtil::Rank(shape)); + DimensionVector v(shape.rank()); int64 scale = 1; for (auto dim : LayoutUtil::MinorToMajor(shape)) { v[dim] = scale; @@ -2700,7 +2699,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Shape& window_shape, const Window& window, const Shape& base_shape, const absl::Span& window_count_index, const std::function&)>& f) { - const int64 rank = ShapeUtil::Rank(base_shape); + const int64 rank = base_shape.rank(); DimensionVector window_index(rank); std::fill(window_index.begin(), window_index.end(), 0); do { @@ -2767,7 +2766,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& start_indices_literal) { auto result = operand_literal.Clone(); auto start_indices_typed = start_indices_literal.data(); - const auto rank = ShapeUtil::Rank(result.shape()); + const auto rank = result.shape().rank(); std::vector start(start_indices_typed.begin(), start_indices_typed.end()); // Clamp the update start indices so the slice is in-bounds w.r.t the diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 3e8903c9537..462fe3b3215 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1039,7 +1039,7 @@ HloInstruction::CreateBroadcastSequence( const std::function)>& adder) { CHECK(ShapeUtil::IsScalar(operand->shape()) || - ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)); + operand->shape().rank() == output_shape.rank()); Shape broadcast_shape = ShapeUtil::ChangeElementType( output_shape, operand->shape().element_type()); // Do explicit broadcast for scalar. @@ -1055,7 +1055,7 @@ HloInstruction::CreateBroadcastSequence( // Do explicit broadcast for degenerate broadcast. std::vector broadcast_dimensions; std::vector reshaped_dimensions; - for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) { + for (int i = 0; i < operand->shape().rank(); i++) { if (operand->shape().dimensions(i) == output_shape.dimensions(i)) { broadcast_dimensions.push_back(i); reshaped_dimensions.push_back(operand->shape().dimensions(i)); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 756e260b60d..977fa01acb9 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -734,7 +734,7 @@ HloMapInstruction::HloMapInstruction(const Shape& shape, AppendComputation(map_computation); // TODO(b/65689298) Remove code below once Map is generalized to accept // arbitrary map dimensions. - dimensions_.resize(ShapeUtil::Rank(shape)); + dimensions_.resize(shape.rank()); std::iota(dimensions_.begin(), dimensions_.end(), 0); } diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 44643951c14..4e9c562c93a 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -1980,7 +1980,7 @@ bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) { } bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { - const tensorflow::int64 rank = ShapeUtil::Rank(shape); + const tensorflow::int64 rank = shape.rank(); // Create a literal with the given shape in default layout. *literal = LiteralUtil::CreateFromDimensions( shape.element_type(), AsInt64Slice(shape.dimensions())); @@ -2145,7 +2145,7 @@ template bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) { std::vector index; - tensorflow::int64 rank = ShapeUtil::Rank(shape); + tensorflow::int64 rank = shape.rank(); *literal = Literal(shape); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 70a860c356c..b8b4fd6135d 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -30,7 +30,7 @@ HloSharding HloSharding::AssignDevice(int64 device_id) { } HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { - CHECK_EQ(1, ShapeUtil::Rank(input_shape)); + CHECK_EQ(1, input_shape.rank()); CHECK_GT(num_tiles, 1); std::vector dimensions(1, num_tiles); Array assignment(dimensions); @@ -340,7 +340,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, } // The tile assignment tensor must have the same rank as the input. - if (ShapeUtil::Rank(shape) != tile_assignment_.num_dimensions()) { + if (shape.rank() != tile_assignment_.num_dimensions()) { return tensorflow::errors::InvalidArgument( "Number of tile assignment dimensions is different to the input rank. " "sharding=", diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index e1c737132f7..5e120b49712 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -349,7 +349,7 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) { Status ShapeVerifier::HandleIota(HloInstruction* instruction) { TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0)); auto* iota = Cast(instruction); - const int64 rank = ShapeUtil::Rank(iota->shape()); + const int64 rank = iota->shape().rank(); if (rank == 0) { return InternalError("Iota does not support scalars."); } @@ -397,13 +397,11 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { const Shape& operand_shape = broadcast->operand(0)->shape(); // Check for mixed precision. TF_RET_CHECK(SameElementType(broadcast->shape(), operand_shape)); - TF_RET_CHECK(ShapeUtil::Rank(operand_shape) == - broadcast->dimensions().size()); - for (int64 operand_dimension = 0; - operand_dimension < ShapeUtil::Rank(operand_shape); + TF_RET_CHECK(operand_shape.rank() == broadcast->dimensions().size()); + for (int64 operand_dimension = 0; operand_dimension < operand_shape.rank(); ++operand_dimension) { int64 output_dimension = broadcast->dimensions()[operand_dimension]; - TF_RET_CHECK((output_dimension < ShapeUtil::Rank(broadcast->shape())) && + TF_RET_CHECK((output_dimension < broadcast->shape().rank()) && output_dimension >= 0 && (broadcast->shape().dimensions(output_dimension) == operand_shape.dimensions(operand_dimension))) @@ -524,8 +522,7 @@ Status ShapeVerifier::HandleMap(HloInstruction* map) { int64 max_operand_rank = 0; for (const HloInstruction* operand : map->operands()) { operand_shapes.push_back(&operand->shape()); - max_operand_rank = - std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); + max_operand_rank = std::max(max_operand_rank, operand->shape().rank()); } // TODO(b/65689298) Remove code below once Map is generalized to accept // arbitrary map dimensions. @@ -1271,11 +1268,11 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I // or ComputationLowerer::Visit() TF_RET_CHECK(broadcast->dimensions().size() == - ShapeUtil::Rank(broadcast->operand(0)->shape())) + broadcast->operand(0)->shape().rank()) << "Broadcast HLO (" << broadcast->ToShortString() << ") has invalid number of dimensions: " << broadcast->dimensions().size() - << " != " << ShapeUtil::Rank(broadcast->operand(0)->shape()); + << " != " << broadcast->operand(0)->shape().rank(); return Status::OK(); } @@ -1376,7 +1373,7 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { for (HloInstruction* operand : instruction->operands()) { const Shape& operand_shape = operand->shape(); if (LayoutUtil::IsDenseArray(operand_shape) && - ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(result_shape)) { + operand_shape.rank() == result_shape.rank()) { const Layout& operand_layout = operand_shape.layout(); TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout)) << "Instruction shouldn't change layouts " diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 1ebb3319779..a41cf714c5e 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -1002,7 +1002,7 @@ bool CanFoldDotIntoIndexedArray( absl::Span contracting_dims, absl::Span batch_dims) { absl::optional non_contracting_non_batch_dim = - GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()), + GetOnlyNonContractingNonBatchDim(indexed_array->shape().rank(), contracting_dims, batch_dims); if (!non_contracting_non_batch_dim.has_value()) { VLOG(3) << tag << ": multiple or no non-contracting non-batch dimensions"; @@ -1015,7 +1015,7 @@ bool CanFoldDotIntoIndexedArray( return false; } - int64 indexed_array_rank = ShapeUtil::Rank(indexed_array->shape()); + int64 indexed_array_rank = indexed_array->shape().rank(); if (indexed_array->source_dim() < (indexed_array_rank - 2)) { // This restriction can be lifted by inserting reshape nodes. VLOG(3) << tag @@ -1043,7 +1043,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( return nullptr; } - int64 lhs_rank = ShapeUtil::Rank(lhs->shape()); + int64 lhs_rank = lhs->shape().rank(); DotDimensionNumbers new_dim_numbers = dim_numbers; new_dim_numbers.set_lhs_contracting_dimensions( 0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1)); @@ -1078,7 +1078,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( return nullptr; } - int64 rhs_rank = ShapeUtil::Rank(rhs->shape()); + int64 rhs_rank = rhs->shape().rank(); DotDimensionNumbers new_dim_numbers = dim_numbers; new_dim_numbers.set_rhs_contracting_dimensions( diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index b9ddd9636fe..d30d2ff9b95 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -991,8 +991,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( CHECK(ShapeUtil::IsArray(instruction->shape())); CHECK(ShapeUtil::IsArray(operand->shape())); if (!ShapeUtil::IsScalar(operand->shape()) && - ShapeUtil::Rank(operand->shape()) == - ShapeUtil::Rank(instruction->shape()) && + operand->shape().rank() == instruction->shape().rank() && !instruction_can_change_layout_func_(instruction)) { // Propagate the result layout to the operand layout if the instruction // requires the same layout out for the result and the operand. @@ -1012,7 +1011,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( // operations. For similar reasons, if the operand and output have the same // rank, try to match the operand's layout to the output. if (ShapeUtil::TrueRank(operand->shape()) == 1 && - ShapeUtil::Rank(instruction->shape()) == 1) { + instruction->shape().rank() == 1) { // Don't assign a layout in case of R1 -> effective R1 reshape. return nullptr; } @@ -1026,7 +1025,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { return absl::make_unique(operand_shape.layout()); } - if (ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)) { + if (operand_shape.rank() == output_shape.rank()) { *operand_shape.mutable_layout() = output_layout; if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { @@ -1045,7 +1044,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( if (instruction->opcode() == HloOpcode::kTranspose) { // Pick the operand layout that makes the transpose a bitcast. - int64 rank = ShapeUtil::Rank(instruction->shape()); + int64 rank = instruction->shape().rank(); std::vector new_minor_to_major(rank); for (int64 i = 0; i < rank; ++i) { int64 output_dim = LayoutUtil::Minor(output_layout, i); @@ -1070,7 +1069,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( ShapeUtil::IsArray(operand->shape())); if (!ShapeUtil::IsScalar(operand->shape()) && - ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape()) && + operand->shape().rank() == user->shape().rank() && !instruction_can_change_layout_func_(user)) { // Assign users the same layout as the operand. return absl::make_unique(operand_layout); @@ -1083,7 +1082,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( // reshape is a bitcast when using the same layout. This may avoid copy // operations. For similar reasons, if the operand and output have the same // rank, try to match the outputs's layout to the operand. - if (ShapeUtil::Rank(operand->shape()) == 1 && + if (operand->shape().rank() == 1 && ShapeUtil::TrueRank(user->shape()) == 1) { // Don't assign a layout in case of R1 -> effective R1 reshape. return nullptr; @@ -1098,7 +1097,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { return absl::make_unique(output_shape.layout()); } - if (ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)) { + if (operand->shape().rank() == output_shape.rank()) { *output_shape.mutable_layout() = operand_layout; if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { @@ -1117,7 +1116,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( if (user->opcode() == HloOpcode::kTranspose) { // Pick the user layout that makes the transpose a bitcast. - int64 rank = ShapeUtil::Rank(user->shape()); + int64 rank = user->shape().rank(); std::vector new_minor_to_major(rank); auto inverse_dimensions = InversePermutation(user->dimensions()); for (int64 i = 0; i < rank; ++i) { @@ -1273,7 +1272,7 @@ Status LayoutAssignment::PropagateOperandConstraint( return Status::OK(); } - int64 operand_rank = ShapeUtil::Rank(operand->shape()); + int64 operand_rank = operand->shape().rank(); if (operand_rank <= 1) { return Status::OK(); } @@ -1288,7 +1287,7 @@ Status LayoutAssignment::PropagateOperandConstraint( continue; } const HloInstruction* sibling = user->operand(operand_no); - const int64 sibling_rank = ShapeUtil::Rank(sibling->shape()); + const int64 sibling_rank = sibling->shape().rank(); if (sibling_rank <= 1) { continue; } @@ -1320,13 +1319,13 @@ Status LayoutAssignment::PropagateOperandConstraint( if (ShapeUtil::IsTuple(subshape)) { return Status::OK(); } - if (ShapeUtil::Rank(subshape) <= 1) { + if (subshape.rank() <= 1) { return Status::OK(); } // Assign the right layout to input fusion of higher rank reduce // operations. - if (ShapeUtil::Rank(subshape) != ShapeUtil::Rank(operand->shape())) { + if (subshape.rank() != operand->shape().rank()) { return Status::OK(); } // TODO(b/67641796): Are there cases except fusion that use this code @@ -1357,7 +1356,7 @@ Status LayoutAssignment::PropagateOperandConstraint( if (ShapeUtil::IsTuple(subshape)) { return Status::OK(); } - if (ShapeUtil::Rank(subshape) <= 1) { + if (subshape.rank() <= 1) { return Status::OK(); } TF_ASSIGN_OR_RETURN( @@ -1402,7 +1401,7 @@ Status LayoutAssignment::PropagateBufferConstraintToOperands( if (!instruction_can_change_layout_func_(instruction)) { // Copy the layout to the operand. if (buffer.IsArray() && ShapeUtil::IsArray(operand->shape()) && - ShapeUtil::Rank(operand->shape()) == + operand->shape().rank() == LayoutUtil::MinorToMajor(buffer_constraint.layout()).size()) { TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( buffer_constraint.layout(), instruction, operand_no, @@ -2101,7 +2100,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( /* static */ bool LayoutAssignment::IsAtMostRank1(const Shape& shape) { if (ShapeUtil::IsArray(shape)) { - return ShapeUtil::Rank(shape) <= 1; + return shape.rank() <= 1; } return absl::c_all_of(shape.tuple_shapes(), [](const Shape& subshape) { return IsAtMostRank1(subshape); diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 31d78752f07..387b385157a 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -528,8 +528,7 @@ class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment { for (int64 operand_no = 0; operand_no < instruction->operand_count(); ++operand_no) { const HloInstruction* operand = instruction->operand(operand_no); - if (ShapeUtil::Rank(instruction->shape()) != - ShapeUtil::Rank(operand->shape())) { + if (instruction->shape().rank() != operand->shape().rank()) { continue; } TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index 4d7f36d9f8b..1da77945328 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -44,7 +44,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( const Shape& output_shape = output_array.GetShape(); // Read start indices from start_indices_generator. - const int64 rank = ShapeUtil::Rank(output_shape); + const int64 rank = output_shape.rank(); IrArray::Index start_index(b->getInt64Ty(), rank); for (int64 i = 0; i < rank; ++i) { IrArray::Index dim_index({b->getInt64(i)}); diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 67f74231211..38078cd52b3 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -61,7 +61,7 @@ void IrArray::Index::Delinearize(std::vector* multidim, IrArray::Index::Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b) - : multidim_(ShapeUtil::Rank(shape)), + : multidim_(shape.rank()), linear_(linear), layout_(shape.layout()), dims_(shape.dimensions().begin(), shape.dimensions().end()) { @@ -120,7 +120,7 @@ IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape) if (!ShapeUtil::IsArray(*shape_) || ShapeUtil::IsScalar(*shape_)) { DCHECK(depth == 1 || depth == 0) << depth; } else { - DCHECK_EQ(depth, ShapeUtil::Rank(*shape_)) << shape.ShortDebugString(); + DCHECK_EQ(depth, shape_->rank()) << shape.ShortDebugString(); } } @@ -137,12 +137,12 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( const Shape& output_shape, const Shape& input_shape, llvm::IRBuilder<>* builder) const { const auto& target_index = *this; - CHECK_EQ(target_index.size(), ShapeUtil::Rank(output_shape)); + CHECK_EQ(target_index.size(), output_shape.rank()); std::vector> common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()), AsInt64Slice(output_shape.dimensions())); std::vector source_multidim_index( - ShapeUtil::Rank(input_shape), llvm::UndefValue::get(index_type_)); + input_shape.rank(), llvm::UndefValue::get(index_type_)); // We compute the source indices in each common factor from only the target // indices in the same common factor. for (ssize_t k = common_factors.size() - 2; k >= 0; --k) { @@ -257,7 +257,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( const Shape& shape, const Shape& operand_shape, absl::Span dimension_mapping, llvm::IRBuilder<>* builder) const { - int64 rank = ShapeUtil::Rank(operand_shape); + int64 rank = operand_shape.rank(); std::vector source_index(rank); for (int64 i = 0; i < rank; ++i) { source_index[i] = multidim_[dimension_mapping[i]]; @@ -271,7 +271,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( // The other dimensions can be masked out with a div and a mod operation. std::vector logical_to_physical = LayoutUtil::MakeLogicalToPhysical(shape.layout()); - int64 output_rank = ShapeUtil::Rank(shape); + int64 output_rank = shape.rank(); // The minimum physical dimension that is broadcasted. int64 min_broadcasted_dimension = output_rank; // The maximum physical dimension that is broadcasted. @@ -348,7 +348,7 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, // over higher-rank arrays. return base_ptr_; } - CHECK_EQ(index.size(), ShapeUtil::Rank(*shape_)); + CHECK_EQ(index.size(), shape_->rank()); if (index.LinearValidOnShape(*shape_)) { llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 219a9f221fb..fe320bbe727 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -235,7 +235,7 @@ std::unique_ptr ForLoopNest::AddLoop(int64 start_index, IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, absl::string_view suffix) { - std::vector dimensions(ShapeUtil::Rank(shape)); + std::vector dimensions(shape.rank()); std::iota(dimensions.begin(), dimensions.end(), 0); return AddLoopsForShapeOnDimensions(shape, dimensions, suffix); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index 6a9406bfeba..89b6a36f96b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -322,7 +322,7 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, // comparisons). const Shape& keys_shape = keys_array.GetShape(); - int64 rank = ShapeUtil::Rank(keys_shape); + int64 rank = keys_shape.rank(); int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); std::vector dimensions_in_iteration_order(rank); std::vector iteration_order_to_logical_order(rank); diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index fdb6a9b01be..b9616e9132d 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -831,7 +831,7 @@ class ShapePatternRankImpl { explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - if (ShapeUtil::Rank(*shape) != rank_) { + if (shape->rank() != rank_) { if (rank_ == 0) { EXPLAIN << "Shape is not a scalar"; } else { diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc index 11c2f8392d2..e8496dbd72b 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.cc +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -88,7 +88,7 @@ static StatusOr CanonicalizeScatterIndices( static StatusOr PermuteScatterAndWindowDims( HloInstruction* updates, absl::Span update_window_dims) { std::vector permutation; - const int64 updates_rank = ShapeUtil::Rank(updates->shape()); + const int64 updates_rank = updates->shape().rank(); permutation.reserve(updates_rank); for (int64 i = 0; i < updates_rank; ++i) { diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 8e571675c79..6d02c810766 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -96,7 +96,7 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape, } for (const Shape* element_shape : accumulator_subshapes) { - if (ShapeUtil::Rank(*element_shape) != 0) { + if (element_shape->rank() != 0) { return InvalidArgument( "Reduction function must return a scalar or tuple of scalars but " "returns shape: %s", @@ -160,10 +160,10 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, const Window& window, PrimitiveType element_type, bool allow_negative_padding) { - if (window.dimensions_size() != ShapeUtil::Rank(base_shape)) { + if (window.dimensions_size() != base_shape.rank()) { return InvalidArgument( "Window has dimension %d but base shape has dimension %d.", - window.dimensions_size(), ShapeUtil::Rank(base_shape)); + window.dimensions_size(), base_shape.rank()); } std::vector output_dimensions(window.dimensions_size()); @@ -338,7 +338,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, if (arg_shapes.empty()) { return InvalidArgument("Concatenate expects at least one argument."); } - if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) { + if (dimension < 0 || dimension >= arg_shapes[0]->rank()) { return InvalidArgument("Concatenate dimension out of bounds: %d.", dimension); } @@ -351,12 +351,12 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, element_type = arg_shape->element_type(); continue; } - if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { + if (arg_shape->rank() != shape->rank()) { return InvalidArgument( "Cannot concatenate arrays with different ranks: %d (%s) vs %d " "(%s).", - ShapeUtil::Rank(*arg_shape), ShapeUtil::HumanString(*arg_shape), - ShapeUtil::Rank(*shape), ShapeUtil::HumanString(*shape)); + arg_shape->rank(), ShapeUtil::HumanString(*arg_shape), shape->rank(), + ShapeUtil::HumanString(*shape)); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) { return InvalidArgument( @@ -364,8 +364,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, PrimitiveType_Name(arg_shape->element_type()), PrimitiveType_Name(shape->element_type())); } - for (int64 dimension_number = 0; - dimension_number < ShapeUtil::Rank(*arg_shape); ++dimension_number) { + for (int64 dimension_number = 0; dimension_number < arg_shape->rank(); + ++dimension_number) { if (arg_shape->dimensions(dimension_number) != shape->dimensions(dimension_number)) { if (dimension_number == dimension) { @@ -480,7 +480,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Pad operation does not support non-scalar padding values."); } - if (ShapeUtil::Rank(operand_shape) != padding_config.dimensions_size()) { + if (operand_shape.rank() != padding_config.dimensions_size()) { return InvalidArgument( "The rank of the operand and the padding configuration do not match: " "%s vs %s.", @@ -500,7 +500,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, padding_config.ShortDebugString()); } - std::vector dimensions(ShapeUtil::Rank(operand_shape)); + std::vector dimensions(operand_shape.rank()); for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) { const auto& p = padding_config.dimensions(i); dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() + @@ -555,9 +555,9 @@ Status ValidateDotDimensionNumbers( absl::Span rhs_batch_dimensions = AsInt64Slice(dimension_numbers.rhs_batch_dimensions()); - if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions, + if (!dims_in_range(lhs.rank(), lhs_contracting_dimensions, lhs_batch_dimensions) || - !dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions, + !dims_in_range(rhs.rank(), rhs_contracting_dimensions, rhs_batch_dimensions)) { return InvalidArgument("A dimension number is out of range in Dot: %s.", dimension_numbers.DebugString()); @@ -583,12 +583,10 @@ Status ValidateDotDimensionNumbers( // Check that the count of non-contracting-non-batch dimensions is in {0, 1}. const int64 lhs_non_contracting_non_batch_dims = - ShapeUtil::Rank(lhs) - - dimension_numbers.lhs_contracting_dimensions_size() - + lhs.rank() - dimension_numbers.lhs_contracting_dimensions_size() - dimension_numbers.lhs_batch_dimensions_size(); const int64 rhs_non_contracting_non_batch_dims = - ShapeUtil::Rank(rhs) - - dimension_numbers.rhs_contracting_dimensions_size() - + rhs.rank() - dimension_numbers.rhs_contracting_dimensions_size() - dimension_numbers.rhs_batch_dimensions_size(); if (lhs_non_contracting_non_batch_dims < 0 || lhs_non_contracting_non_batch_dims > 1 || @@ -637,7 +635,7 @@ Status ValidateDotDimensionNumbers( return fail("Element types do not match."); } - if ((ShapeUtil::Rank(lhs) < 1) || (ShapeUtil::Rank(rhs) < 1)) { + if ((lhs.rank() < 1) || (rhs.rank() < 1)) { return fail("Dot only supports rank 1 or above."); } @@ -686,12 +684,12 @@ Status ValidateDotDimensionNumbers( std::unordered_set rhs_batch_dims( dimension_numbers.rhs_batch_dimensions().begin(), dimension_numbers.rhs_batch_dimensions().end()); - for (int64 i = 0; i < ShapeUtil::Rank(lhs); i++) { + for (int64 i = 0; i < lhs.rank(); i++) { if (i != lhs_contracting_dimension) { dimensions.push_back(lhs.dimensions(i)); } } - for (int64 i = 0; i < ShapeUtil::Rank(rhs); i++) { + for (int64 i = 0; i < rhs.rank(); i++) { if (i != rhs_contracting_dimension && rhs_batch_dims.count(i) == 0) { dimensions.push_back(rhs.dimensions(i)); } @@ -708,14 +706,14 @@ Status ValidateDotDimensionNumbers( ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& lhs, const Shape& rhs) { - TF_RET_CHECK(ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)); + TF_RET_CHECK(lhs.rank() == rhs.rank()); // The shapes have to be compatible. That is, if some dimension d has a // different size in the two shapes, one of them has to be 1 (a "degenerate" // dimension). In that case, the output shape has the non-1 dimension size // from the lhs/rhs pair in every index. - std::vector output_dimensions(ShapeUtil::Rank(lhs)); - for (int64 i = 0; i < ShapeUtil::Rank(lhs); ++i) { + std::vector output_dimensions(lhs.rank()); + for (int64 i = 0; i < lhs.rank(); ++i) { if (lhs.dimensions(i) == rhs.dimensions(i)) { output_dimensions[i] = lhs.dimensions(i); } else if (lhs.dimensions(i) == 1) { @@ -743,13 +741,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument("Automatic shape inference not supported: %s and %s", ShapeUtil::HumanString(smaller_shape), ShapeUtil::HumanString(larger_shape)); - } else if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) { + } else if (broadcast_dimensions.size() != smaller_shape.rank()) { return InvalidArgument( "Size of broadcast_dimensions has to match lower-rank operand's " "rank; " " lower-rank operand's rank is %d, size of broadcast_dimensions is " "%u.", - ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size()); + smaller_shape.rank(), broadcast_dimensions.size()); } // broadcast_dimensions is a sequence of dimensions; its length is equal to @@ -847,8 +845,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(rhs)); } - if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) { - std::vector identity_dims(ShapeUtil::Rank(lhs)); + if (lhs.rank() == rhs.rank()) { + std::vector identity_dims(lhs.rank()); std::iota(identity_dims.begin(), identity_dims.end(), 0); if (!broadcast_dimensions.empty() && broadcast_dimensions != identity_dims) { @@ -865,15 +863,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs)); } - if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) { + if (lhs.rank() == rhs.rank()) { return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs); } else { // Ranks do not match, so perform InDim broadcasting using // broadcast_dimensions. Scalar broadcasting is a special case of this. - const Shape& larger_shape = - ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? lhs : rhs; - const Shape& smaller_shape = - ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? rhs : lhs; + const Shape& larger_shape = lhs.rank() > rhs.rank() ? lhs : rhs; + const Shape& smaller_shape = lhs.rank() > rhs.rank() ? rhs : lhs; // After InDim broadcasting, perform degenerate dimensions broadcasting. TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape, @@ -1162,12 +1158,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) == Status::OK()); - if (feature_index >= ShapeUtil::Rank(operand_shape)) { + if (feature_index >= operand_shape.rank()) { return InvalidArgument( "Expected feature_index of batch-norm-training to be " "smaller than the rank of operand_shape; " "got feature_index %d, and rank %d.", - feature_index, ShapeUtil::Rank(operand_shape)); + feature_index, operand_shape.rank()); } if (feature_index < 0) { @@ -1177,25 +1173,25 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, feature_index); } - if (ShapeUtil::Rank(operand_shape) < 1) { + if (operand_shape.rank() < 1) { return InvalidArgument( "Expected the rank of operand to " "batch-norm-training to be at least 1; got %d.", - ShapeUtil::Rank(operand_shape)); + operand_shape.rank()); } - if (ShapeUtil::Rank(offset_shape) != 1) { + if (offset_shape.rank() != 1) { return InvalidArgument( "Offset input of batch-norm-training must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(offset_shape)); + offset_shape.rank()); } - if (ShapeUtil::Rank(scale_shape) != 1) { + if (scale_shape.rank() != 1) { return InvalidArgument( "Scale input of batch-norm-training must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(scale_shape)); + scale_shape.rank()); } if (!ShapeUtil::ElementIsFloating(operand_shape)) { @@ -1272,12 +1268,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape) == Status::OK()); - if (feature_index >= ShapeUtil::Rank(operand_shape)) { + if (feature_index >= operand_shape.rank()) { return InvalidArgument( "Expected feature_index of batch-norm-inference to be " "smaller than the rank of operand_shape; " "got feature_index %d, and rank %d.", - feature_index, ShapeUtil::Rank(operand_shape)); + feature_index, operand_shape.rank()); } if (feature_index < 0) { @@ -1287,25 +1283,25 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, feature_index); } - if (ShapeUtil::Rank(operand_shape) < 1) { + if (operand_shape.rank() < 1) { return InvalidArgument( "Expected the rank of operand to " "batch-norm-inference to be at least 1; got %d.", - ShapeUtil::Rank(operand_shape)); + operand_shape.rank()); } - if (ShapeUtil::Rank(offset_shape) != 1) { + if (offset_shape.rank() != 1) { return InvalidArgument( "Offset input of batch-norm-inference must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(offset_shape)); + offset_shape.rank()); } - if (ShapeUtil::Rank(scale_shape) != 1) { + if (scale_shape.rank() != 1) { return InvalidArgument( "Scale input of batch-norm-inference must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(scale_shape)); + scale_shape.rank()); } if (!ShapeUtil::ElementIsFloating(operand_shape)) { @@ -1417,41 +1413,41 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RETURN_IF_ERROR( ShapeUtil::ValidateShapeWithOptionalLayout(output_grad_shape)); - if (feature_index >= ShapeUtil::Rank(operand_shape)) { + if (feature_index >= operand_shape.rank()) { return InvalidArgument( "Expected feature_index of batch-norm-grad to be " "smaller than the rank of operand_shape; " "got feature_index %d, and rank %d.", - feature_index, ShapeUtil::Rank(operand_shape)); + feature_index, operand_shape.rank()); } - if (ShapeUtil::Rank(operand_shape) != ShapeUtil::Rank(output_grad_shape)) { + if (operand_shape.rank() != output_grad_shape.rank()) { return InvalidArgument( "Expected operand_shape of batch-norm-grad to have the same rank as" " output_grad_shape; got rank(oprand_shape) %d, and" " rank(output_grad_shape) %d.", - ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(output_grad_shape)); + operand_shape.rank(), output_grad_shape.rank()); } - if (ShapeUtil::Rank(mean_shape) != 1) { + if (mean_shape.rank() != 1) { return InvalidArgument( "Mean input of batch-norm-grad must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(mean_shape)); + mean_shape.rank()); } - if (ShapeUtil::Rank(scale_shape) != 1) { + if (scale_shape.rank() != 1) { return InvalidArgument( "Scale input of batch-norm-grad must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(scale_shape)); + scale_shape.rank()); } - if (ShapeUtil::Rank(var_shape) != 1) { + if (var_shape.rank() != 1) { return InvalidArgument( "Var input of batch-norm-grad must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(var_shape)); + var_shape.rank()); } if (!ShapeUtil::ElementIsFloating(operand_shape)) { @@ -1538,7 +1534,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } // Verify operand_shape and output_grad_shape have same bounds. - for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { + for (int64 i = 0; i < operand_shape.rank(); ++i) { if (ShapeUtil::GetDimension(operand_shape, i) != ShapeUtil::GetDimension(output_grad_shape, i)) { return InvalidArgument( @@ -1603,12 +1599,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } const int num_dims = num_spatial_dims + 2; - if (ShapeUtil::Rank(lhs) != num_dims) { + if (lhs.rank() != num_dims) { return InvalidArgument( "The LHS argument to a convolution should have rank %d; lhs: %s.", num_dims, ShapeUtil::HumanString(lhs)); } - if (ShapeUtil::Rank(rhs) != num_dims) { + if (rhs.rank() != num_dims) { return InvalidArgument( "The RHS argument to a convolution should have rank %d; rhs: %s.", num_dims, ShapeUtil::HumanString(rhs)); @@ -1853,12 +1849,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& shape, int64 split_dimension, int64 concat_dimension, int64 split_count) { TF_RET_CHECK(split_count > 0); - if (split_dimension >= ShapeUtil::Rank(shape) || split_dimension < 0) { + if (split_dimension >= shape.rank() || split_dimension < 0) { return InvalidArgument( "AllToAll split_dimension %d is out-of-bounds in shape %s.", split_dimension, ShapeUtil::HumanString(shape)); } - if (concat_dimension >= ShapeUtil::Rank(shape) || concat_dimension < 0) { + if (concat_dimension >= shape.rank() || concat_dimension < 0) { return InvalidArgument( "AllToAll concat_dimension %d is out-of-bounds in shape %s.", concat_dimension, ShapeUtil::HumanString(shape)); @@ -1932,7 +1928,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, // doesn't matter which one we choose. const Shape& arg = *reduced_args[0]; for (int64 dimension : dimensions_to_reduce) { - if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) { + if (dimension >= arg.rank() || dimension < 0) { return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.", dimension, ShapeUtil::HumanString(arg)); } @@ -1949,7 +1945,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, std::set dimensions_to_reduce_set(dimensions_to_reduce.begin(), dimensions_to_reduce.end()); std::vector new_dimensions; - for (int i = 0; i < ShapeUtil::Rank(arg); ++i) { + for (int i = 0; i < arg.rank(); ++i) { if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) { new_dimensions.push_back(arg.dimensions(i)); } @@ -2041,7 +2037,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferGetDimensionSizeShape( const Shape& shape, int64 dimension) { - if (dimension < 0 || dimension >= ShapeUtil::Rank(shape)) { + if (dimension < 0 || dimension >= shape.rank()) { return InvalidArgument("GetDimensionSize dimension out of bounds: %d.", dimension); } @@ -2083,10 +2079,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, starts.size(), strides.size())); } - if (starts.size() != ShapeUtil::Rank(arg)) { + if (starts.size() != arg.rank()) { return InvalidArgument( "Slice index count does not match argument rank: %u vs %d.", - starts.size(), ShapeUtil::Rank(arg)); + starts.size(), arg.rank()); } std::vector sizes; @@ -2132,10 +2128,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(operand_shape), ShapeUtil::HumanString(start_indices_shape), StrJoin(slice_sizes, ", ")); - if (ShapeUtil::Rank(start_indices_shape) != 1) { + if (start_indices_shape.rank() != 1) { return InvalidArgument( "Dynamic slice start indices of rank %d must be rank1.", - ShapeUtil::Rank(start_indices_shape)); + start_indices_shape.rank()); } if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { @@ -2144,18 +2140,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } const int64 start_num_dims = start_indices_shape.dimensions(0); - if (ShapeUtil::Rank(operand_shape) != start_num_dims) { + if (operand_shape.rank() != start_num_dims) { return InvalidArgument( "Dynamic slice start number of dimensions %d (%s) must match rank " "%d of slice input (%s).", start_num_dims, ShapeUtil::HumanString(start_indices_shape), - ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape)); + operand_shape.rank(), ShapeUtil::HumanString(operand_shape)); } - if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) { + if (slice_sizes.size() != operand_shape.rank()) { return InvalidArgument( "Dynamic slice index count does not match argument rank: %u vs %d.", - slice_sizes.size(), ShapeUtil::Rank(operand_shape)); + slice_sizes.size(), operand_shape.rank()); } for (int64 dim = 0; dim < slice_sizes.size(); ++dim) { @@ -2193,10 +2189,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(start_indices_shape), ShapeUtil::HumanString(update_shape)); - if (ShapeUtil::Rank(start_indices_shape) != 1) { + if (start_indices_shape.rank() != 1) { return InvalidArgument( "Dynamic update slice start indices of rank %d must be rank1.", - ShapeUtil::Rank(start_indices_shape)); + start_indices_shape.rank()); } if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { @@ -2205,19 +2201,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } const int64 start_num_dims = start_indices_shape.dimensions(0); - if (ShapeUtil::Rank(operand_shape) != start_num_dims) { + if (operand_shape.rank() != start_num_dims) { return InvalidArgument( "Dynamic update slice start number of dimensions %d (%s) must match " "rank %d of slice input (%s).", start_num_dims, ShapeUtil::HumanString(start_indices_shape), - ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape)); + operand_shape.rank(), ShapeUtil::HumanString(operand_shape)); } - if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) { + if (update_shape.rank() != operand_shape.rank()) { return InvalidArgument( "Dynamic update slice update rank does not match argument rank: " "%d vs %d.", - ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape)); + update_shape.rank(), operand_shape.rank()); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, @@ -2229,7 +2225,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, PrimitiveType_Name(update_shape.element_type())); } - for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) { + for (int64 dim = 0; dim < operand_shape.rank(); ++dim) { const int64 input_dim_size = operand_shape.dimensions(dim); const int64 update_dim_size = update_shape.dimensions(dim); if (update_dim_size < 0) { @@ -2255,7 +2251,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument("a dimension number is duplicated in reverse"); } for (int64 dimension : dimensions) { - if (dimension >= ShapeUtil::Rank(operand_shape) || dimension < 0) { + if (dimension >= operand_shape.rank() || dimension < 0) { return InvalidArgument( "One of the reverse dimensions (%d) is out-of-bounds in shape %s.", dimension, ShapeUtil::HumanString(operand_shape)); @@ -2397,8 +2393,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, absl::Span broadcast_dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of broadcast")); TF_RETURN_IF_ERROR(ExpectArray(output_shape, "operand of broadcast")); - const int64 operand_rank = ShapeUtil::Rank(operand_shape); - const int64 output_rank = ShapeUtil::Rank(output_shape); + const int64 operand_rank = operand_shape.rank(); + const int64 output_rank = output_shape.rank(); if (operand_rank > output_rank) { return InvalidArgument( "InDim style broadcast must be to an equal or higher ranked shape; " @@ -2457,9 +2453,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(inferred_shape)); } - std::vector indices(ShapeUtil::Rank(operand)); + std::vector indices(operand.rank()); std::iota(indices.begin(), indices.end(), 0); - if (dimensions.size() != ShapeUtil::Rank(operand) || + if (dimensions.size() != operand.rank() || !std::is_permutation(dimensions.begin(), dimensions.end(), indices.begin())) { return InvalidArgument( @@ -2475,9 +2471,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& operand, absl::Span dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose")); - std::vector indices(ShapeUtil::Rank(operand)); + std::vector indices(operand.rank()); std::iota(indices.begin(), indices.end(), 0); - if (dimensions.size() != ShapeUtil::Rank(operand) || + if (dimensions.size() != operand.rank() || !std::is_permutation(dimensions.begin(), dimensions.end(), indices.begin())) { return InvalidArgument( @@ -2829,7 +2825,7 @@ Status ValidateScatterDimensionNumbers( "update_window_dims in scatter op must not repeat; got: %s.", StrJoin(dim_numbers.update_window_dims(), ", ")); } - const int64 updates_rank = ShapeUtil::Rank(updates_shape); + const int64 updates_rank = updates_shape.rank(); for (int64 window_dim : dim_numbers.update_window_dims()) { if (window_dim < 0 || window_dim >= updates_rank) { return InvalidArgument( @@ -2863,10 +2859,10 @@ Status ValidateScatterDimensionNumbers( // Validate window size. auto window_size = dim_numbers.update_window_dims_size() + dim_numbers.inserted_window_dims_size(); - if (window_size != ShapeUtil::Rank(operand_shape)) { + if (window_size != operand_shape.rank()) { return InvalidArgument( "Scatter op has window of size %d; doesn't match operand of rank %d.", - window_size, ShapeUtil::Rank(operand_shape)); + window_size, operand_shape.rank()); } // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers. @@ -2951,10 +2947,9 @@ Status ValidateScatterDimensionNumbers( int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 + scatter_dim_numbers.update_window_dims_size(); - if (ShapeUtil::Rank(updates_shape) != expected_updates_rank) { + if (updates_shape.rank() != expected_updates_rank) { return InvalidArgument("Updates tensor must be of rank %d; got %d.", - expected_updates_rank, - ShapeUtil::Rank(updates_shape)); + expected_updates_rank, updates_shape.rank()); } TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers( @@ -2985,7 +2980,7 @@ Status ValidateScatterDimensionNumbers( } int64 scatter_dims_seen = 0; - for (int64 i = 0; i < ShapeUtil::Rank(updates_shape); ++i) { + for (int64 i = 0; i < updates_shape.rank(); ++i) { bool is_update_window_dim = absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i); if (is_update_window_dim) { diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index eaf4f28b87c..15eb46bac0a 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -45,7 +45,7 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot( auto& operand = *dot.operand(i); if (operand.IsRank2Transpose()) { operand_set.push_back(i); - } else if (ShapeUtil::Rank(operand.shape()) != 2) { + } else if (operand.shape().rank() != 2) { return {}; } } diff --git a/tensorflow/compiler/xla/shape.cc b/tensorflow/compiler/xla/shape.cc index b206345db2a..cec8dc2c20f 100644 --- a/tensorflow/compiler/xla/shape.cc +++ b/tensorflow/compiler/xla/shape.cc @@ -61,6 +61,12 @@ string Shape::ToString(bool print_layout) const { } } +int64 Shape::rank() const { + CHECK(ShapeUtil::IsArray(*this)) + << "Non-arrays do not have a rank, shape: " << *this; + return dimensions_.size(); +} + std::ostream& operator<<(std::ostream& out, const Shape& shape) { out << shape.ToString(/*print_layout=*/true); return out; diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h index 7643f64d8a5..91edafe2f7c 100644 --- a/tensorflow/compiler/xla/shape.h +++ b/tensorflow/compiler/xla/shape.h @@ -44,6 +44,10 @@ class Shape { // without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]". string ToString(bool print_layout = false) const; + // Returns the rank (number of dimensions) of the given shape. Shape must be + // an array. + int64 rank() const; + // The following methods mirror the protobuf generated code interface for the // message ShapeProto. This enabled easy migration of this data structure // from a proto to a proper C++ class. diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index be7d71ada00..8ad241f2c9c 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -679,8 +679,8 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) { TF_DCHECK_OK(ValidateShape(shape)); CHECK(LayoutUtil::IsSparseArray(shape)); - return LayoutUtil::MaxSparseElements(shape.layout()) * - ShapeUtil::Rank(shape) * sizeof(int64); + return LayoutUtil::MaxSparseElements(shape.layout()) * shape.rank() * + sizeof(int64); } /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( @@ -763,7 +763,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return sparse_elements_size; } int64 sparse_indices_size = - MultiplyWithoutOverflow(max_sparse_elements, ShapeUtil::Rank(shape)); + MultiplyWithoutOverflow(max_sparse_elements, shape.rank()); if (sparse_indices_size < 0) { return sparse_indices_size; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 8a7d755951e..e1c37d79e28 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -298,6 +298,7 @@ class ShapeUtil { // Returns the rank (number of dimensions) of the given shape. // Precondition: !IsTuple(shape) + ABSL_DEPRECATED("Use `Shape::rank` instead.") static int64 Rank(const Shape& shape); // Returns the number of dimensions for which the dimension is not (trivially) diff --git a/tensorflow/compiler/xla/sparse_index_array.cc b/tensorflow/compiler/xla/sparse_index_array.cc index a40bb7875e7..82091bdee65 100644 --- a/tensorflow/compiler/xla/sparse_index_array.cc +++ b/tensorflow/compiler/xla/sparse_index_array.cc @@ -79,7 +79,7 @@ void SparseIndexArray::Resize(int64 num_indices) { } bool SparseIndexArray::Validate(const Shape& shape) const { - if (rank_ == 0 || rank_ != ShapeUtil::Rank(shape)) { + if (rank_ == 0 || rank_ != shape.rank()) { return false; } int64 num_indices = index_count(); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index a3507155970..20f6c189da9 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -191,7 +191,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( verify_output(actual, ""); // Try with all output layouts. - std::vector minor_to_major(ShapeUtil::Rank(expected.shape())); + std::vector minor_to_major(expected.shape().rank()); std::iota(minor_to_major.begin(), minor_to_major.end(), 0); do { auto layout = ShapeUtil::MakeShapeWithLayout( @@ -234,7 +234,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( return Status::OK(); } - std::vector minor_to_major(ShapeUtil::Rank(literal.shape())); + std::vector minor_to_major(literal.shape().rank()); std::iota(minor_to_major.begin(), minor_to_major.end(), 0); do { auto literal_relayout = diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index eafa48ed7b8..96ccda6a793 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -348,7 +348,7 @@ StatusOr CreateLiteralForConstrainedUses( const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice ? use->shape() : use->operand(1)->shape(); - const int64 rank = ShapeUtil::Rank(indexed_shape); + const int64 rank = indexed_shape.rank(); if (!index_space.empty()) { TF_RET_CHECK(rank == index_space.size()); for (int64 i = 0; i < rank; ++i) { @@ -459,8 +459,8 @@ Status VerifyHloModule(HloModule* const module, bool layout_sensitive, std::unique_ptr CreateCanonicalDot(const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) { - CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); - CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); + CHECK_EQ(lhs->shape().rank(), 2); + CHECK_EQ(rhs->shape().rank(), 2); PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( 2, PrecisionConfig::DEFAULT);