Legalize MinimumBroadcastShapes op.

Use it in TransformUnrankedHloPass, which allows to reduce the maximum
rank for rank specialized broadcast from 6 to 5.

PiperOrigin-RevId: 360415743
Change-Id: I3af377cfc49a2be33432c91d7ae06ca4009ac051
This commit is contained in:
Adrian Kuegel 2021-03-02 06:37:48 -08:00 committed by TensorFlower Gardener
parent 8f9e5f03d1
commit acb619833a
7 changed files with 487 additions and 80 deletions

View File

@ -51,6 +51,7 @@ struct ChloLegalizeToHloPass
conversionTarget.addLegalDialect<
MhloDialect, mlir::StandardOpsDialect, mlir::tensor::TensorDialect,
mlir::shape::ShapeDialect, mlir::scf::SCFDialect>();
conversionTarget.addLegalOp<chlo::MinimumBroadcastShapesOp>();
if (broadcast_only_) {
chlo::PopulateChloBroadcastingPatterns(&getContext(),

View File

@ -223,9 +223,8 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
}
static Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op,
Value value, int targeted_rank) {
Value shape, int targeted_rank) {
auto loc = op.getLoc();
Value shape = builder.create<shape::ShapeOfOp>(loc, value);
SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
{RankedTensorType::kDynamicSize}, builder.getIndexType());
@ -246,6 +245,7 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
static void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder,
ChloOpTy op,
ValueRange operands,
ValueRange operand_shapes,
int targeted_rank) {
auto loc = op.getLoc();
SmallVector<Value, 2> reshaped_operands;
@ -253,10 +253,12 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
targeted_rank, RankedTensorType::kDynamicSize);
for (Value operand : operands) {
for (auto it : llvm::zip(operands, operand_shapes)) {
Value operand, shape;
std::tie(operand, shape) = it;
// Handle shape broadcasting and inference.
Value extended_operand_casted =
createBroadcastToKnownRank(if_builder, op, operand, targeted_rank);
createBroadcastToKnownRank(if_builder, op, shape, targeted_rank);
// 1. Reshape operands to the given rank (with the same number of
// elements)
@ -290,13 +292,37 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
ValueRange operands) {
auto loc = op.getLoc();
// Find the larger rank of the operands.
// Get the minimum broadcast shapes of the operands.
SmallVector<Value> shapes;
shapes.reserve(operands.size());
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
rewriter.getIndexType());
Value greater_rank;
for (Value operand : operands) {
Value shape =
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, operand);
shapes.push_back(shape);
}
auto broadcast_shape = rewriter.create<shape::BroadcastOp>(
loc, extent_tensor_type, shapes, nullptr);
SmallVector<Type> result_types(shapes.size(), extent_tensor_type);
auto reduced_shapes =
rewriter
.create<chlo::MinimumBroadcastShapesOp>(loc, result_types, shapes)
.results();
SmallVector<Value> reshaped_operands;
reshaped_operands.reserve(operands.size());
for (auto it : llvm::zip(operands, reduced_shapes)) {
Value operand;
Value reduced_shape;
std::tie(operand, reduced_shape) = it;
auto reshaped_operand = rewriter.create<mhlo::DynamicReshapeOp>(
loc, operand.getType(), operand, reduced_shape);
reshaped_operands.push_back(reshaped_operand);
}
// Find the largest rank of the operands.
Value greater_rank;
for (Value shape : reduced_shapes) {
Value rank =
rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), shape);
if (!greater_rank) {
@ -314,17 +340,19 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
rewriter, op, greater_rank, 1);
OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
createRankSpecializedBroadcastAndOp(if_builder, op, operands, 1);
createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
reduced_shapes, 1);
// Put each subsequent rank specialization inside the else statement of the
// previous one.
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
constexpr int kMaxRankSpecialization = 6;
constexpr int kMaxRankSpecialization = 5;
for (int i = 2; i < kMaxRankSpecialization; i++) {
auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
else_builder, op, greater_rank, i);
if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
createRankSpecializedBroadcastAndOp(if_builder, op, operands, i);
createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
reduced_shapes, i);
else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
}
@ -336,12 +364,15 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
kMaxRankSpecialization),
"Input for dynamic binary op lowering was of a rank greater than " +
std::to_string(kMaxRankSpecialization));
// Add the rank 6 specialization to the innermost else block.
createRankSpecializedBroadcastAndOp(else_builder, op, operands,
kMaxRankSpecialization);
// Add the rank 5 specialization to the innermost else block.
createRankSpecializedBroadcastAndOp(else_builder, op, reshaped_operands,
reduced_shapes, kMaxRankSpecialization);
// Return the result of the outermost if statement.
return if_op.getResult(0);
// Return the reshaped result of the outermost if statement.
auto result = if_op.getResult(0);
auto reshaped_result = rewriter.create<mhlo::DynamicReshapeOp>(
loc, result.getType(), result, broadcast_shape);
return reshaped_result;
}
};
@ -497,16 +528,17 @@ struct ConvertUnrankedDynamicBroadcastSelectOp
struct TransformUnrankedHloPass
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect,
shape::ShapeDialect>();
}
void runOnFunction() override {
// Setup conversion target.
MLIRContext &ctx = getContext();
ConversionTarget target(ctx);
target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
shape::ShapeDialect, scf::SCFDialect,
tensor::TensorDialect>();
target.addLegalDialect<chlo::HloClientDialect, mhlo::MhloDialect,
StandardOpsDialect, shape::ShapeDialect,
scf::SCFDialect, tensor::TensorDialect>();
target.addLegalOp<FuncOp>();
#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)

View File

@ -199,20 +199,24 @@ func @addUnrankedUnranked(
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// Handle rank 1 specialization
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[RESULT_RANK_SPECIALIZATION:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
@ -222,12 +226,12 @@ func @addUnrankedUnranked(
// Handle rank 2 specialization
// CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32>
@ -237,12 +241,12 @@ func @addUnrankedUnranked(
// Handle rank 3 specialization
// CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32>
@ -252,47 +256,30 @@ func @addUnrankedUnranked(
// Handle rank 4 specialization
// CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C5:.*]] = constant 5 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C5]] : index
// CHECK-NEXT: assert %[[GREATEST_RANK_IS_5]]
// Handle rank 5 specialization
// CHECK-NEXT: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C6:.*]] = constant 6 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_6:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C6]] : index
// CHECK-NEXT: assert %[[GREATEST_RANK_IS_6]]
// Handle rank 6 specialization
// CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor.cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor.cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_6:.*]] = tensor.cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_6]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_65:.*]] : tensor<*xf32>
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
// CHECK-NEXT: }
@ -300,7 +287,8 @@ func @addUnrankedUnranked(
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_69:.*]] : tensor<*xf32>
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESULT_RANK_SPECIALIZATION]], %[[RESULT_SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
// CHECK-NEXT: }
@ -325,13 +313,18 @@ func @selectUnrankedUnrankedUnranked(
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[PRED_SHAPE:.*]] = shape.shape_of %[[PRED]] : tensor<*xi1> -> tensor<?xindex>
// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[PRED_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
// CHECK-NEXT: %[[MINIMUM_RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xi1>, tensor<?xindex>) -> tensor<*xi1>
// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
// CHECK-NEXT: %[[GREATER_RANK_CMP:.*]] = cmpi sgt, %[[PRED_RANK]], %[[LHS_RANK]] : index
// CHECK-NEXT: %[[GREATER_RANK:.*]] = select %[[GREATER_RANK_CMP]], %[[PRED_RANK]], %[[LHS_RANK]] : index
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#2 : tensor<?xindex> -> index
// CHECK-NEXT: %[[GREATEST_RANK_CMP:.*]] = cmpi sgt, %[[GREATER_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[GREATEST_RANK_CMP]], %[[GREATER_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %c1 = constant 1 : index
@ -339,15 +332,15 @@ func @selectUnrankedUnrankedUnranked(
// Handle rank 1 specialization
// CHECK-NEXT: scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] : tensor<1xindex>
// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_PRED:.*]] = tensor.cast %[[BROADCASTED_PRED]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS:.*]] = tensor.cast %[[BROADCASTED_LHS]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#2, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS:.*]] = tensor.cast %[[BROADCASTED_RHS]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1:.*]] : tensor<?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
@ -357,4 +350,3 @@ func @selectUnrankedUnrankedUnranked(
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?xi1>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?xi1>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>

View File

@ -108,3 +108,95 @@ func @const_splat() -> tensor<3xf32> {
%result = constant dense<4.0> : tensor<3xf32>
return %result : tensor<3xf32>
}
// CHECK-LABEL: @minimum_broadcast_shapes
// CHECK-SAME: (%[[LHS:.*]]: memref<?xindex>, %[[RHS:.*]]: memref<?xindex>)
func @minimum_broadcast_shapes(%lhs: tensor<?xindex>, %rhs: tensor<?xindex>) -> (tensor<?xindex>, tensor<?xindex>) {
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
// CHECK-NEXT: %[[RANK_LHS:.*]] = dim %[[LHS]], %[[C0]] : memref<?xindex>
// CHECK-NEXT: %[[TRUE:.*]] = constant true
// CHECK-NEXT: %[[C0_0:.*]] = constant 0 : index
// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
// CHECK-NEXT: %[[FOR_0:.*]]:2 = scf.for %[[IV:.*]] = %[[C0_0]] to %[[RANK_LHS]] step %[[C1]] iter_args(%[[ALL_ONES:.*]] = %[[TRUE]], %[[ONE_COUNT:.*]] = %[[C0_0]]) -> (i1, index) {
// CHECK-NEXT: %[[SIZE:.*]] = load %[[LHS]][%[[IV]]] : memref<?xindex>
// CHECK-NEXT: %[[IS_ONE:.*]] = cmpi eq, %[[SIZE]], %[[C1]] : index
// CHECK-NEXT: %[[NEXT_ALL_ONES:.*]] = and %[[ALL_ONES]], %[[IS_ONE]] : i1
// CHECK-NEXT: %[[ONE_COUNT_PLUS_ONE:.*]] = addi %[[ONE_COUNT]], %[[C1]] : index
// CHECK-NEXT: %[[NEXT_ONE_COUNT:.*]] = select %[[NEXT_ALL_ONES]], %[[ONE_COUNT_PLUS_ONE]], %[[ONE_COUNT]] : index
// CHECK-NEXT: scf.yield %[[NEXT_ALL_ONES]], %[[NEXT_ONE_COUNT]] : i1, index
// CHECK-NEXT: }
// CHECK-NEXT: %[[REDUCED_RANK_LHS:.*]] = subi %[[RANK_LHS]], %[[FOR_0]]#1 : index
// CHECK-NEXT: %[[RANK_RHS:.*]] = dim %[[RHS]], %[[C0]] : memref<?xindex>
// CHECK: %[[REDUCED_RANK_RHS:.*]] = subi %[[RANK_RHS]], %[[LEADING_ONES:.*]]#1 : index
// CHECK-NEXT: %[[IS_GREATER_RANK:.*]] = cmpi ugt, %[[REDUCED_RANK_RHS]], %[[REDUCED_RANK_LHS]] : index
// CHECK-NEXT: %[[MAX_RANK:.*]] = select %[[IS_GREATER_RANK]], %[[REDUCED_RANK_RHS]], %[[REDUCED_RANK_LHS]] : index
// CHECK-NEXT: %[[C1_1:.*]] = constant 1 : index
// CHECK-NEXT: %[[RESULT_LHS:.*]] = alloca(%[[REDUCED_RANK_LHS]]) : memref<?xindex>
// CHECK-NEXT: scf.for %[[IV:.*]] = %[[C0]] to %[[REDUCED_RANK_LHS]] step %[[C1_1]] {
// CHECK-NEXT: store %[[C1_1]], %[[RESULT_LHS]][%[[IV]]] : memref<?xindex>
// CHECK-NEXT: }
// CHECK-NEXT: %[[RESULT_RHS:.*]] = alloca(%[[REDUCED_RANK_RHS]]) : memref<?xindex>
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK-NEXT: %[[UPPER_BOUND:.*]] = addi %[[MAX_RANK]], %[[C2]] : index
// CHECK-NEXT: %[[MAIN_FOR:.*]]:2 = scf.for %[[IV:.*]] = %[[C1_1]] to %[[UPPER_BOUND]] step %[[C1_1]] iter_args(%[[RUNNING_PRODUCT:.*]] = %[[C1_1]], %[[OFFSET:.*]] = %[[C0]]) -> (index, index) {
// CHECK-NEXT: %[[FALSE:.*]] = constant false
// CHECK-NEXT: %[[MINUS_ONE:.*]] = constant -1 : index
// CHECK-NEXT: %[[IS_OUT_OF_BOUNDS:.*]] = cmpi ult, %[[REDUCED_RANK_LHS]], %[[IV]] : index
// CHECK-NEXT: %[[DIMENSION:.*]] = subi %[[RANK_LHS]], %[[IV]] : index
// CHECK-NEXT: %[[RESULT_DIMENSION:.*]] = subi %[[DIMENSION]], %[[FOR_0]]#1 : index
// CHECK-NEXT: %[[CURRENT_SIZE:.*]] = scf.if %[[IS_OUT_OF_BOUNDS]] -> (index) {
// CHECK-NEXT: scf.yield %[[MINUS_ONE]] : index
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[SIZE:.*]] = load %[[LHS]][%[[DIMENSION]]] : memref<?xindex>
// CHECK-NEXT: scf.yield %[[SIZE]] : index
// CHECK-NEXT: }
// CHECK-NEXT: %[[IS_INITIALIZED:.*]] = cmpi ne, %[[MINUS_ONE]], %[[MINUS_ONE]] : index
// CHECK-NEXT: %[[SAME_SIZE:.*]] = select %[[IS_INITIALIZED]], %[[MINUS_ONE]], %[[CURRENT_SIZE]] : index
// CHECK-NEXT: %[[IS_DIFFERENT_SIZE:.*]] = cmpi ne, %[[CURRENT_SIZE]], %[[SAME_SIZE]] : index
// CHECK-NEXT: %[[NEW_SAME_SIZE:.*]] = select %[[IS_DIFFERENT_SIZE]], %[[CURRENT_SIZE]], %[[SAME_SIZE]] : index
// CHECK-NEXT: %[[DIFFERENT_SIZES:.*]] = or %[[FALSE]], %[[IS_DIFFERENT_SIZE]] : i1
// CHECK-NEXT: %[[IS_ONE_OUT_OF_BOUNDS:.*]] = cmpi eq, %[[RESULT_DIMENSION]], %[[MINUS_ONE]] : index
// CHECK-NEXT: %[[JUST_OUT_OF_BOUNDS:.*]] = or %[[FALSE]], %[[IS_ONE_OUT_OF_BOUNDS]] : i1
// CHECK: %[[IS_INITIALIZED:.*]] = cmpi ne, %[[NEW_SAME_SIZE]], %[[MINUS_ONE]] : index
// CHECK-NEXT: %[[SAME_SIZE:.*]] = select %[[IS_INITIALIZED]], %[[NEW_SAME_SIZE]], %[[CURRENT_SIZE_1:.*]] : index
// CHECK-NEXT: %[[IS_DIFFERENT_SIZE:.*]] = cmpi ne, %[[CURRENT_SIZE_1]], %[[SAME_SIZE]] : index
// CHECK-NEXT: %[[FINAL_SAME_SIZE:.*]] = select %[[IS_DIFFERENT_SIZE]], %[[CURRENT_SIZE_1]], %[[SAME_SIZE]] : index
// CHECK: %[[FINAL_DIFFERENT_SIZES:.*]] = or %[[DIFFERENT_SIZES]], %[[IS_DIFFERENT_SIZE:.*]] : i1
// CHECK: %[[FINAL_JUST_OUT_OF_BOUNDS:.*]] = or %[[JUST_OUT_OF_BOUNDS]], %[[IS_ONE_OUT_OF_BOUNDS:.*]] : i1
// CHECK-NEXT: %[[STOP_COMBINING_DIMENSIONS:.*]] = or %[[FINAL_DIFFERENT_SIZES]], %[[FINAL_JUST_OUT_OF_BOUNDS]] : i1
// CHECK-NEXT: %[[IF_STOP_COMBINING_DIMENSIONS:.*]]:2 = scf.if %[[STOP_COMBINING_DIMENSIONS]] -> (index, index) {
// CHECK-NEXT: %[[IS_RUNNING_PRODUCT_NOT_ONE:.*]] = cmpi ne, %[[RUNNING_PRODUCT]], %[[C1_1]] : index
// CHECK-NEXT: %[[NEW_OFFSET_1:.*]] = scf.if %[[IS_RUNNING_PRODUCT_NOT_ONE]] -> (index) {
// CHECK-NEXT: %[[NEW_OFFSET_0:.*]] = addi %[[OFFSET]], %[[C1_1]] : index
// CHECK-NEXT: %[[WAS_IN_BOUNDS:.*]] = cmpi sge, %[[RESULT_DIMENSION]], %[[MINUS_ONE]] : index
// CHECK-NEXT: scf.if %[[WAS_IN_BOUNDS]] {
// CHECK-NEXT: %[[CURRENT_DIMENSION:.*]] = subi %[[REDUCED_RANK_LHS]], %[[NEW_OFFSET_0]] : index
// CHECK-NEXT: store %[[RUNNING_PRODUCT]], %[[RESULT_LHS]][%[[CURRENT_DIMENSION]]] : memref<?xindex>
// CHECK-NEXT: }
// CHECK: scf.yield %[[NEW_OFFSET_0]] : index
// CHECK-NEXT: } else {
// CHECK-NEXT: scf.yield %[[OFFSET]] : index
// CHECK-NEXT: }
// CHECK-NEXT: %[[IF_DIFFERENT_SIZES:.*]]:2 = scf.if %[[FINAL_DIFFERENT_SIZES]] -> (index, index) {
// CHECK-NEXT: %[[NEW_OFFSET_2:.*]] = addi %[[NEW_OFFSET_1]], %[[C1_1]] : index
// CHECK-NEXT: %[[IS_IN_BOUNDS:.*]] = cmpi sge, %[[RESULT_DIMENSION]], %[[C0]] : index
// CHECK-NEXT: scf.if %[[IS_IN_BOUNDS]] {
// CHECK-NEXT: %[[CURRENT_DIMENSION:.*]] = subi %[[REDUCED_RANK_LHS]], %[[NEW_OFFSET_2]] : index
// CHECK-NEXT: store %[[CURRENT_SIZE]], %[[RESULT_LHS]][%[[CURRENT_DIMENSION]]] : memref<?xindex>
// CHECK-NEXT: }
// CHECK: scf.yield %[[C1_1]], %[[NEW_OFFSET_2]] : index, index
// CHECK-NEXT: } else {
// CHECK-NEXT: scf.yield %[[FINAL_SAME_SIZE]], %[[NEW_OFFSET_1]] : index, index
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[IF_DIFFERENT_SIZES]]#0, %[[IF_DIFFERENT_SIZES]]#1 : index, index
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[NEW_RUNNING_PRODUCT:.*]] = muli %[[RUNNING_PRODUCT]], %[[FINAL_SAME_SIZE]] : index
// CHECK-NEXT: scf.yield %[[NEW_RUNNING_PRODUCT]], %[[OFFSET]] : index, index
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[IF_STOP_COMBINING_DIMENSIONS]]#0, %[[IF_STOP_COMBINING_DIMENSIONS]]#1 : index, index
// CHECK-NEXT: }
// CHECK: return %[[SUBVIEW_LHS:.*]], %[[SUBVIEW_RHS:.*]] : memref<?xindex>, memref<?xindex>
%0, %1 = chlo.minimum_broadcast_shapes %lhs, %rhs :
tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
return %0, %1 : tensor<?xindex>, tensor<?xindex>
}

View File

@ -37,6 +37,7 @@ cc_library(
hdrs = ["rewriters.h"],
compatible_with = get_compatible_with_cloud(),
deps = [
"//tensorflow/compiler/mlir/hlo",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",

View File

@ -22,7 +22,9 @@ limitations under the License.
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
namespace mlir {
@ -86,6 +88,289 @@ class BufferizeDimOp : public OpConversionPattern<DimOp> {
}
};
class BufferizeAndConvertMinimumBroadcastShapesOp
: public OpConversionPattern<chlo::MinimumBroadcastShapesOp> {
public:
using OpConversionPattern<
chlo::MinimumBroadcastShapesOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
chlo::MinimumBroadcastShapesOp broadcast_shapes_op,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
chlo::MinimumBroadcastShapesOp::Adaptor adaptor(operands);
auto loc = broadcast_shapes_op.getLoc();
ImplicitLocOpBuilder lb(loc, rewriter);
Value zero = lb.create<ConstantIndexOp>(0);
SmallVector<Value> shapes = adaptor.shapes();
size_t k = shapes.size();
SmallVector<Value> ranks;
ranks.reserve(k);
SmallVector<Value> real_ranks;
real_ranks.reserve(k);
SmallVector<Value> leading_ones;
leading_ones.reserve(k);
// Determine the "real" rank of each operand shape by counting leading 1's.
for (size_t i = 0; i < k; ++i) {
Value rank = lb.create<DimOp>(loc, shapes[i], zero);
ranks.push_back(rank);
leading_ones.push_back(CountLeadingOnes(lb, shapes[i], rank));
Value real_rank = lb.create<SubIOp>(rank, leading_ones[i]);
real_ranks.push_back(real_rank);
}
// Determine the maximum real rank of the operands.
Value max_rank = real_ranks[0];
for (size_t i = 1; i < k; ++i) {
Value rank_is_greater =
lb.create<CmpIOp>(CmpIPredicate::ugt, real_ranks[i], max_rank);
max_rank = lb.create<SelectOp>(rank_is_greater, real_ranks[i], max_rank);
}
// Allocate buffers for the return values and initialize them with 1's.
SmallVector<Value> result_shapes;
result_shapes.reserve(k);
auto result_type =
MemRefType::get({ShapedType::kDynamicSize}, lb.getIndexType());
Value one = lb.create<ConstantIndexOp>(1);
for (size_t i = 0; i < k; ++i) {
// We assume the buffer will be small, so we allocate it on the stack.
// TODO(b/181654096): Replace AllocaOp with AllocOp.
auto result = lb.create<AllocaOp>(result_type, real_ranks[i]);
lb.create<scf::ForOp>(zero, real_ranks[i], one, llvm::None,
[&one, &result](OpBuilder &b, Location l, Value idx,
ValueRange /*vr*/) {
b.create<StoreOp>(l, one, result, idx);
b.create<scf::YieldOp>(l, llvm::None);
});
result_shapes.push_back(result);
}
// Iterate through the dimensions and determine which adjacent dimensions
// can be combined. Keep a running product of the dimensions that can be
// combined as iteration variable (initialized to 1), and the current
// dimension offset in the result shapes. We iterate through the shapes
// backward, because the broadcasting semantics mean that the last
// dimensions of each shape (the least significant ones) are matched
// together.
Value running_product = one;
Value current_dimension_offset = zero;
Value two = lb.create<ConstantIndexOp>(2);
Value max_rank_plus_two = lb.create<AddIOp>(loc, max_rank, two);
// Iterate from 1 to max_rank + 1 (inclusive). This iteration variable is
// used as an offset from the end of each shape vector. We iterate until
// max_rank + 1 to handle the case that we have a running_product > 1 left
// when we have processed all dimensions of the largest shape.
lb.create<scf::ForOp>(
one, max_rank_plus_two, one,
ValueRange{running_product, current_dimension_offset},
[&](OpBuilder &b, Location l, Value v, ValueRange vr) {
Value constant_false =
b.create<ConstantOp>(l, b.getI1Type(), b.getBoolAttr(false));
Value just_out_of_bounds = constant_false;
Value different_sizes = constant_false;
Value minus_one = b.create<ConstantIndexOp>(l, -1);
// Initialize 'same_size' to a size that we don't expect to see.
Value same_size = minus_one;
// 'result_dimensions' stores the current dimension with an offset of
// 'leading_ones' to make it easier to check whether we are in-bounds
// with respect to the "real" shape with leading 1's removed.
SmallVector<Value> result_dimensions;
SmallVector<Value> sizes;
result_dimensions.reserve(k);
sizes.reserve(k);
// This loop checks whether we have at least two shapes with different
// sizes at the current dimension, and whether we just ran out of
// bounds in at least one shape.
for (size_t i = 0; i < k; ++i) {
// Determine the size of the dimension. If the dimension is out of
// bounds, we choose the value 'same_size', because then the shape
// should not affect the check anymore whether there are two shapes
// with different sizes at the current dimension.
Value is_out_of_bounds =
b.create<CmpIOp>(l, CmpIPredicate::ult, real_ranks[i], v);
Value dimension = b.create<SubIOp>(l, ranks[i], v);
Value result_dimension =
b.create<SubIOp>(l, dimension, leading_ones[i]);
result_dimensions.push_back(result_dimension);
Value current_size =
b.create<scf::IfOp>(
l, TypeRange{b.getIndexType()}, is_out_of_bounds,
[&](OpBuilder &b, Location l) {
b.create<scf::YieldOp>(l, same_size);
},
[&](OpBuilder &b, Location l) {
// Using IfOp instead of SelectOp makes sure that we
// don't try to load if the dimension is out of bounds.
Value size = b.create<LoadOp>(l, shapes[i], dimension);
b.create<scf::YieldOp>(l, size);
})
.getResult(0);
sizes.push_back(current_size);
Value is_initialized =
b.create<CmpIOp>(l, CmpIPredicate::ne, same_size, minus_one);
same_size =
b.create<SelectOp>(l, is_initialized, same_size, current_size);
Value is_different_size =
b.create<CmpIOp>(l, CmpIPredicate::ne, current_size, same_size);
same_size = b.create<SelectOp>(l, is_different_size, current_size,
same_size);
different_sizes =
b.create<OrOp>(l, different_sizes, is_different_size);
Value is_one_out_of_bounds = b.create<CmpIOp>(
l, CmpIPredicate::eq, result_dimension, minus_one);
just_out_of_bounds =
b.create<OrOp>(l, just_out_of_bounds, is_one_out_of_bounds);
}
Value running_product = vr.front();
Value current_dimension_offset = vr.back();
// We need to stop combining dimensions if we just ran out of bounds
// in one shape, or there are at least two shapes with different sizes
// at the current dimension.
Value stop_combining_dimensions =
b.create<OrOp>(l, different_sizes, just_out_of_bounds);
auto if_stop_combining_dimensions = b.create<scf::IfOp>(
l, TypeRange{b.getIndexType(), b.getIndexType()},
stop_combining_dimensions,
[&](OpBuilder &b, Location l) {
// If the running product is not 1, add one dimension of size
// 'running_product' to each shape that is still indexed
// in-bounds or has just gone out of bounds.
Value running_product_not_one = b.create<CmpIOp>(
l, CmpIPredicate::ne, running_product, one);
Value new_dimension_offset =
b.create<scf::IfOp>(
l, TypeRange{b.getIndexType()},
running_product_not_one,
[&](OpBuilder &b, Location l) {
Value new_dimension_offset = b.create<AddIOp>(
l, current_dimension_offset, one);
for (size_t i = 0; i < k; ++i) {
Value was_in_bounds = b.create<CmpIOp>(
l, CmpIPredicate::sge, result_dimensions[i],
minus_one);
b.create<scf::IfOp>(
l, was_in_bounds,
[&](OpBuilder &b, Location l) {
Value output_dimension = b.create<SubIOp>(
l, real_ranks[i], new_dimension_offset);
b.create<StoreOp>(l, running_product,
result_shapes[i],
output_dimension);
b.create<scf::YieldOp>(l, llvm::None);
});
}
b.create<scf::YieldOp>(l, new_dimension_offset);
},
[&](OpBuilder &b, Location l) {
b.create<scf::YieldOp>(l, current_dimension_offset);
})
.getResult(0);
// If there are at least two different sizes, copy the dimension
// size from the input to the output shapes for all shapes that
// are still indexed in-bounds.
auto if_different_sizes = b.create<scf::IfOp>(
l, TypeRange{b.getIndexType(), b.getIndexType()},
different_sizes,
[&](OpBuilder &b, Location l) {
Value dimension_offset =
b.create<AddIOp>(l, new_dimension_offset, one);
for (size_t i = 0; i < k; ++i) {
Value is_in_bounds = b.create<CmpIOp>(
l, CmpIPredicate::sge, result_dimensions[i], zero);
b.create<scf::IfOp>(
l, is_in_bounds, [&](OpBuilder &b, Location l) {
Value output_dimension = b.create<SubIOp>(
l, real_ranks[i], dimension_offset);
b.create<StoreOp>(l, sizes[i], result_shapes[i],
output_dimension);
b.create<scf::YieldOp>(l, llvm::None);
});
}
b.create<scf::YieldOp>(l,
ValueRange{one, dimension_offset});
},
[&](OpBuilder &b, Location l) {
b.create<scf::YieldOp>(
l, ValueRange{same_size, new_dimension_offset});
});
b.create<scf::YieldOp>(l, if_different_sizes.getResults());
},
[&](OpBuilder &b, Location l) {
Value new_running_product =
b.create<MulIOp>(l, running_product, same_size);
b.create<scf::YieldOp>(l, ValueRange{new_running_product,
current_dimension_offset});
});
b.create<scf::YieldOp>(l, if_stop_combining_dimensions.getResults());
});
for (size_t i = 0; i < k; ++i) {
result_shapes[i] =
RemoveLeadingOnesFrom1DMemref(lb, result_shapes[i], real_ranks[i]);
}
rewriter.replaceOp(broadcast_shapes_op, result_shapes);
return success();
}
private:
Value CountLeadingOnes(ImplicitLocOpBuilder &lb, Value extent_memref,
Value rank) const {
// Count leading 1's. Use two iteration variables for that: one with a
// boolean flag for whether every size so far was 1, one with the number of
// leading 1's.
Value constant_true =
lb.create<ConstantOp>(lb.getI1Type(), lb.getBoolAttr(true));
Value zero = lb.create<ConstantIndexOp>(0);
Value one = lb.create<ConstantIndexOp>(1);
auto leading_ones_loop = lb.create<scf::ForOp>(
zero, rank, one, ValueRange{constant_true, zero},
[&](OpBuilder &b, Location l, Value idx, ValueRange vr) {
auto size = b.create<LoadOp>(l, extent_memref, idx);
auto is_equal_to_one =
b.create<CmpIOp>(l, CmpIPredicate::eq, size, one);
auto all_ones = b.create<AndOp>(l, vr.front(), is_equal_to_one);
auto increased_value = b.create<AddIOp>(l, vr.back(), one);
auto number_of_leading_ones =
b.create<SelectOp>(l, all_ones, increased_value, vr.back());
b.create<scf::YieldOp>(l,
ValueRange{all_ones, number_of_leading_ones});
});
return leading_ones_loop.results()[1];
}
Value RemoveLeadingOnesFrom1DMemref(ImplicitLocOpBuilder &lb,
Value extent_memref, Value rank) const {
Value leading_ones = CountLeadingOnes(lb, extent_memref, rank);
Value new_rank = lb.create<SubIOp>(rank, leading_ones);
auto result_type =
MemRefType::get({ShapedType::kDynamicSize}, lb.getIndexType());
// Ideally we would use SubView here to return a MemRef with 'leading_ones'
// as offset, but several things related to MemRef with offsets are
// currently broken, so instead we just allocate another buffer of the
// desired size and copy the elements over. We assume the buffer will be
// small, so we allocate it on the stack.
// TODO(b/181654096): Replace AllocaOp with AllocOp.
Value result = lb.create<AllocaOp>(result_type, new_rank);
Value zero = lb.create<ConstantIndexOp>(0);
Value one = lb.create<ConstantIndexOp>(1);
lb.create<scf::ForOp>(
zero, new_rank, one, llvm::None,
[&](OpBuilder &b, Location l, Value idx, ValueRange /*vr*/) {
Value idx_with_offset = b.create<AddIOp>(l, idx, leading_ones);
auto size = b.create<LoadOp>(l, extent_memref, idx_with_offset);
b.create<StoreOp>(l, size, result, idx);
b.create<scf::YieldOp>(l, llvm::None);
});
return result;
}
};
class BufferizeRankOp : public OpConversionPattern<RankOp> {
public:
using OpConversionPattern::OpConversionPattern;
@ -102,8 +387,10 @@ class BufferizeRankOp : public OpConversionPattern<RankOp> {
void populateExtraStdBufferizePattern(MLIRContext *context,
BufferizeTypeConverter *converter,
OwningRewritePatternList *patterns) {
patterns->insert<BufferizeConstantOp, BufferizeDimOp, BufferizeRankOp>(
*converter, context);
patterns
->insert<BufferizeConstantOp, BufferizeDimOp,
BufferizeAndConvertMinimumBroadcastShapesOp, BufferizeRankOp>(
*converter, context);
}
} // namespace transforms

View File

@ -42,6 +42,7 @@ limitations under the License.
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Transforms/Bufferize.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
@ -172,7 +173,8 @@ struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> {
target.addIllegalDialect<mhlo::MhloDialect>();
target.addIllegalOp<tensor::GenerateOp, tensor::ExtractOp,
tensor::FromElementsOp, tensor::CastOp, TensorLoadOp,
tensor::FromElementsOp, tensor::CastOp,
chlo::MinimumBroadcastShapesOp, TensorLoadOp,
TensorToMemrefOp>();
BufferizeTypeConverter converter;
auto typesAreLegal = [&converter](Operation* op) {