Legalize MinimumBroadcastShapes op.
Use it in TransformUnrankedHloPass, which allows to reduce the maximum rank for rank specialized broadcast from 6 to 5. PiperOrigin-RevId: 360415743 Change-Id: I3af377cfc49a2be33432c91d7ae06ca4009ac051
This commit is contained in:
parent
8f9e5f03d1
commit
acb619833a
@ -51,6 +51,7 @@ struct ChloLegalizeToHloPass
|
||||
conversionTarget.addLegalDialect<
|
||||
MhloDialect, mlir::StandardOpsDialect, mlir::tensor::TensorDialect,
|
||||
mlir::shape::ShapeDialect, mlir::scf::SCFDialect>();
|
||||
conversionTarget.addLegalOp<chlo::MinimumBroadcastShapesOp>();
|
||||
|
||||
if (broadcast_only_) {
|
||||
chlo::PopulateChloBroadcastingPatterns(&getContext(),
|
||||
|
||||
@ -223,9 +223,8 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
||||
}
|
||||
|
||||
static Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op,
|
||||
Value value, int targeted_rank) {
|
||||
Value shape, int targeted_rank) {
|
||||
auto loc = op.getLoc();
|
||||
Value shape = builder.create<shape::ShapeOfOp>(loc, value);
|
||||
SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
|
||||
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
|
||||
{RankedTensorType::kDynamicSize}, builder.getIndexType());
|
||||
@ -246,6 +245,7 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
||||
static void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder,
|
||||
ChloOpTy op,
|
||||
ValueRange operands,
|
||||
ValueRange operand_shapes,
|
||||
int targeted_rank) {
|
||||
auto loc = op.getLoc();
|
||||
SmallVector<Value, 2> reshaped_operands;
|
||||
@ -253,10 +253,12 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
||||
auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
|
||||
targeted_rank, RankedTensorType::kDynamicSize);
|
||||
|
||||
for (Value operand : operands) {
|
||||
for (auto it : llvm::zip(operands, operand_shapes)) {
|
||||
Value operand, shape;
|
||||
std::tie(operand, shape) = it;
|
||||
// Handle shape broadcasting and inference.
|
||||
Value extended_operand_casted =
|
||||
createBroadcastToKnownRank(if_builder, op, operand, targeted_rank);
|
||||
createBroadcastToKnownRank(if_builder, op, shape, targeted_rank);
|
||||
|
||||
// 1. Reshape operands to the given rank (with the same number of
|
||||
// elements)
|
||||
@ -290,13 +292,37 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
||||
ValueRange operands) {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
// Find the larger rank of the operands.
|
||||
// Get the minimum broadcast shapes of the operands.
|
||||
SmallVector<Value> shapes;
|
||||
shapes.reserve(operands.size());
|
||||
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||
rewriter.getIndexType());
|
||||
Value greater_rank;
|
||||
for (Value operand : operands) {
|
||||
Value shape =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, operand);
|
||||
shapes.push_back(shape);
|
||||
}
|
||||
auto broadcast_shape = rewriter.create<shape::BroadcastOp>(
|
||||
loc, extent_tensor_type, shapes, nullptr);
|
||||
SmallVector<Type> result_types(shapes.size(), extent_tensor_type);
|
||||
auto reduced_shapes =
|
||||
rewriter
|
||||
.create<chlo::MinimumBroadcastShapesOp>(loc, result_types, shapes)
|
||||
.results();
|
||||
SmallVector<Value> reshaped_operands;
|
||||
reshaped_operands.reserve(operands.size());
|
||||
for (auto it : llvm::zip(operands, reduced_shapes)) {
|
||||
Value operand;
|
||||
Value reduced_shape;
|
||||
std::tie(operand, reduced_shape) = it;
|
||||
auto reshaped_operand = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, operand.getType(), operand, reduced_shape);
|
||||
reshaped_operands.push_back(reshaped_operand);
|
||||
}
|
||||
|
||||
// Find the largest rank of the operands.
|
||||
Value greater_rank;
|
||||
for (Value shape : reduced_shapes) {
|
||||
Value rank =
|
||||
rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), shape);
|
||||
if (!greater_rank) {
|
||||
@ -314,17 +340,19 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
||||
scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
|
||||
rewriter, op, greater_rank, 1);
|
||||
OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
|
||||
createRankSpecializedBroadcastAndOp(if_builder, op, operands, 1);
|
||||
createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
|
||||
reduced_shapes, 1);
|
||||
|
||||
// Put each subsequent rank specialization inside the else statement of the
|
||||
// previous one.
|
||||
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
|
||||
constexpr int kMaxRankSpecialization = 6;
|
||||
constexpr int kMaxRankSpecialization = 5;
|
||||
for (int i = 2; i < kMaxRankSpecialization; i++) {
|
||||
auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
|
||||
else_builder, op, greater_rank, i);
|
||||
if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
|
||||
createRankSpecializedBroadcastAndOp(if_builder, op, operands, i);
|
||||
createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
|
||||
reduced_shapes, i);
|
||||
else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
|
||||
else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
|
||||
}
|
||||
@ -336,12 +364,15 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
||||
kMaxRankSpecialization),
|
||||
"Input for dynamic binary op lowering was of a rank greater than " +
|
||||
std::to_string(kMaxRankSpecialization));
|
||||
// Add the rank 6 specialization to the innermost else block.
|
||||
createRankSpecializedBroadcastAndOp(else_builder, op, operands,
|
||||
kMaxRankSpecialization);
|
||||
// Add the rank 5 specialization to the innermost else block.
|
||||
createRankSpecializedBroadcastAndOp(else_builder, op, reshaped_operands,
|
||||
reduced_shapes, kMaxRankSpecialization);
|
||||
|
||||
// Return the result of the outermost if statement.
|
||||
return if_op.getResult(0);
|
||||
// Return the reshaped result of the outermost if statement.
|
||||
auto result = if_op.getResult(0);
|
||||
auto reshaped_result = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, result.getType(), result, broadcast_shape);
|
||||
return reshaped_result;
|
||||
}
|
||||
};
|
||||
|
||||
@ -497,16 +528,17 @@ struct ConvertUnrankedDynamicBroadcastSelectOp
|
||||
struct TransformUnrankedHloPass
|
||||
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
|
||||
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect,
|
||||
shape::ShapeDialect>();
|
||||
}
|
||||
|
||||
void runOnFunction() override {
|
||||
// Setup conversion target.
|
||||
MLIRContext &ctx = getContext();
|
||||
ConversionTarget target(ctx);
|
||||
target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
|
||||
shape::ShapeDialect, scf::SCFDialect,
|
||||
tensor::TensorDialect>();
|
||||
target.addLegalDialect<chlo::HloClientDialect, mhlo::MhloDialect,
|
||||
StandardOpsDialect, shape::ShapeDialect,
|
||||
scf::SCFDialect, tensor::TensorDialect>();
|
||||
target.addLegalOp<FuncOp>();
|
||||
#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
|
||||
#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
|
||||
|
||||
@ -199,20 +199,24 @@ func @addUnrankedUnranked(
|
||||
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
|
||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// Handle rank 1 specialization
|
||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index
|
||||
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[RESULT_RANK_SPECIALIZATION:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
|
||||
@ -222,12 +226,12 @@ func @addUnrankedUnranked(
|
||||
// Handle rank 2 specialization
|
||||
// CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32>
|
||||
@ -237,12 +241,12 @@ func @addUnrankedUnranked(
|
||||
// Handle rank 3 specialization
|
||||
// CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32>
|
||||
@ -252,47 +256,30 @@ func @addUnrankedUnranked(
|
||||
// Handle rank 4 specialization
|
||||
// CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32>
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[C5:.*]] = constant 5 : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C5]] : index
|
||||
// CHECK-NEXT: assert %[[GREATEST_RANK_IS_5]]
|
||||
// Handle rank 5 specialization
|
||||
// CHECK-NEXT: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[C6:.*]] = constant 6 : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_6:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C6]] : index
|
||||
// CHECK-NEXT: assert %[[GREATEST_RANK_IS_6]]
|
||||
// Handle rank 6 specialization
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor.cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor.cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_6:.*]] = tensor.cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_6]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_65:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
@ -300,7 +287,8 @@ func @addUnrankedUnranked(
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_69:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESULT_RANK_SPECIALIZATION]], %[[RESULT_SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESHAPED_RESULT]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
@ -325,13 +313,18 @@ func @selectUnrankedUnrankedUnranked(
|
||||
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
|
||||
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK-NEXT: %[[PRED_SHAPE:.*]] = shape.shape_of %[[PRED]] : tensor<*xi1> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[PRED_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
|
||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xi1>, tensor<?xindex>) -> tensor<*xi1>
|
||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[GREATER_RANK_CMP:.*]] = cmpi sgt, %[[PRED_RANK]], %[[LHS_RANK]] : index
|
||||
// CHECK-NEXT: %[[GREATER_RANK:.*]] = select %[[GREATER_RANK_CMP]], %[[PRED_RANK]], %[[LHS_RANK]] : index
|
||||
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#2 : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK_CMP:.*]] = cmpi sgt, %[[GREATER_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[GREATEST_RANK_CMP]], %[[GREATER_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK-NEXT: %c1 = constant 1 : index
|
||||
@ -339,15 +332,15 @@ func @selectUnrankedUnrankedUnranked(
|
||||
// Handle rank 1 specialization
|
||||
// CHECK-NEXT: scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] : tensor<1xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_PRED:.*]] = tensor.cast %[[BROADCASTED_PRED]] : tensor<?xindex> to tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS:.*]] = tensor.cast %[[BROADCASTED_LHS]] : tensor<?xindex> to tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#2, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS:.*]] = tensor.cast %[[BROADCASTED_RHS]] : tensor<?xindex> to tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1:.*]] : tensor<?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
|
||||
@ -357,4 +350,3 @@ func @selectUnrankedUnrankedUnranked(
|
||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?xi1>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?xi1>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
|
||||
|
||||
@ -108,3 +108,95 @@ func @const_splat() -> tensor<3xf32> {
|
||||
%result = constant dense<4.0> : tensor<3xf32>
|
||||
return %result : tensor<3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @minimum_broadcast_shapes
|
||||
// CHECK-SAME: (%[[LHS:.*]]: memref<?xindex>, %[[RHS:.*]]: memref<?xindex>)
|
||||
func @minimum_broadcast_shapes(%lhs: tensor<?xindex>, %rhs: tensor<?xindex>) -> (tensor<?xindex>, tensor<?xindex>) {
|
||||
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK-NEXT: %[[RANK_LHS:.*]] = dim %[[LHS]], %[[C0]] : memref<?xindex>
|
||||
// CHECK-NEXT: %[[TRUE:.*]] = constant true
|
||||
// CHECK-NEXT: %[[C0_0:.*]] = constant 0 : index
|
||||
// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK-NEXT: %[[FOR_0:.*]]:2 = scf.for %[[IV:.*]] = %[[C0_0]] to %[[RANK_LHS]] step %[[C1]] iter_args(%[[ALL_ONES:.*]] = %[[TRUE]], %[[ONE_COUNT:.*]] = %[[C0_0]]) -> (i1, index) {
|
||||
// CHECK-NEXT: %[[SIZE:.*]] = load %[[LHS]][%[[IV]]] : memref<?xindex>
|
||||
// CHECK-NEXT: %[[IS_ONE:.*]] = cmpi eq, %[[SIZE]], %[[C1]] : index
|
||||
// CHECK-NEXT: %[[NEXT_ALL_ONES:.*]] = and %[[ALL_ONES]], %[[IS_ONE]] : i1
|
||||
// CHECK-NEXT: %[[ONE_COUNT_PLUS_ONE:.*]] = addi %[[ONE_COUNT]], %[[C1]] : index
|
||||
// CHECK-NEXT: %[[NEXT_ONE_COUNT:.*]] = select %[[NEXT_ALL_ONES]], %[[ONE_COUNT_PLUS_ONE]], %[[ONE_COUNT]] : index
|
||||
// CHECK-NEXT: scf.yield %[[NEXT_ALL_ONES]], %[[NEXT_ONE_COUNT]] : i1, index
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %[[REDUCED_RANK_LHS:.*]] = subi %[[RANK_LHS]], %[[FOR_0]]#1 : index
|
||||
// CHECK-NEXT: %[[RANK_RHS:.*]] = dim %[[RHS]], %[[C0]] : memref<?xindex>
|
||||
// CHECK: %[[REDUCED_RANK_RHS:.*]] = subi %[[RANK_RHS]], %[[LEADING_ONES:.*]]#1 : index
|
||||
// CHECK-NEXT: %[[IS_GREATER_RANK:.*]] = cmpi ugt, %[[REDUCED_RANK_RHS]], %[[REDUCED_RANK_LHS]] : index
|
||||
// CHECK-NEXT: %[[MAX_RANK:.*]] = select %[[IS_GREATER_RANK]], %[[REDUCED_RANK_RHS]], %[[REDUCED_RANK_LHS]] : index
|
||||
// CHECK-NEXT: %[[C1_1:.*]] = constant 1 : index
|
||||
// CHECK-NEXT: %[[RESULT_LHS:.*]] = alloca(%[[REDUCED_RANK_LHS]]) : memref<?xindex>
|
||||
// CHECK-NEXT: scf.for %[[IV:.*]] = %[[C0]] to %[[REDUCED_RANK_LHS]] step %[[C1_1]] {
|
||||
// CHECK-NEXT: store %[[C1_1]], %[[RESULT_LHS]][%[[IV]]] : memref<?xindex>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %[[RESULT_RHS:.*]] = alloca(%[[REDUCED_RANK_RHS]]) : memref<?xindex>
|
||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK-NEXT: %[[UPPER_BOUND:.*]] = addi %[[MAX_RANK]], %[[C2]] : index
|
||||
// CHECK-NEXT: %[[MAIN_FOR:.*]]:2 = scf.for %[[IV:.*]] = %[[C1_1]] to %[[UPPER_BOUND]] step %[[C1_1]] iter_args(%[[RUNNING_PRODUCT:.*]] = %[[C1_1]], %[[OFFSET:.*]] = %[[C0]]) -> (index, index) {
|
||||
// CHECK-NEXT: %[[FALSE:.*]] = constant false
|
||||
// CHECK-NEXT: %[[MINUS_ONE:.*]] = constant -1 : index
|
||||
// CHECK-NEXT: %[[IS_OUT_OF_BOUNDS:.*]] = cmpi ult, %[[REDUCED_RANK_LHS]], %[[IV]] : index
|
||||
// CHECK-NEXT: %[[DIMENSION:.*]] = subi %[[RANK_LHS]], %[[IV]] : index
|
||||
// CHECK-NEXT: %[[RESULT_DIMENSION:.*]] = subi %[[DIMENSION]], %[[FOR_0]]#1 : index
|
||||
// CHECK-NEXT: %[[CURRENT_SIZE:.*]] = scf.if %[[IS_OUT_OF_BOUNDS]] -> (index) {
|
||||
// CHECK-NEXT: scf.yield %[[MINUS_ONE]] : index
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[SIZE:.*]] = load %[[LHS]][%[[DIMENSION]]] : memref<?xindex>
|
||||
// CHECK-NEXT: scf.yield %[[SIZE]] : index
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %[[IS_INITIALIZED:.*]] = cmpi ne, %[[MINUS_ONE]], %[[MINUS_ONE]] : index
|
||||
// CHECK-NEXT: %[[SAME_SIZE:.*]] = select %[[IS_INITIALIZED]], %[[MINUS_ONE]], %[[CURRENT_SIZE]] : index
|
||||
// CHECK-NEXT: %[[IS_DIFFERENT_SIZE:.*]] = cmpi ne, %[[CURRENT_SIZE]], %[[SAME_SIZE]] : index
|
||||
// CHECK-NEXT: %[[NEW_SAME_SIZE:.*]] = select %[[IS_DIFFERENT_SIZE]], %[[CURRENT_SIZE]], %[[SAME_SIZE]] : index
|
||||
// CHECK-NEXT: %[[DIFFERENT_SIZES:.*]] = or %[[FALSE]], %[[IS_DIFFERENT_SIZE]] : i1
|
||||
// CHECK-NEXT: %[[IS_ONE_OUT_OF_BOUNDS:.*]] = cmpi eq, %[[RESULT_DIMENSION]], %[[MINUS_ONE]] : index
|
||||
// CHECK-NEXT: %[[JUST_OUT_OF_BOUNDS:.*]] = or %[[FALSE]], %[[IS_ONE_OUT_OF_BOUNDS]] : i1
|
||||
// CHECK: %[[IS_INITIALIZED:.*]] = cmpi ne, %[[NEW_SAME_SIZE]], %[[MINUS_ONE]] : index
|
||||
// CHECK-NEXT: %[[SAME_SIZE:.*]] = select %[[IS_INITIALIZED]], %[[NEW_SAME_SIZE]], %[[CURRENT_SIZE_1:.*]] : index
|
||||
// CHECK-NEXT: %[[IS_DIFFERENT_SIZE:.*]] = cmpi ne, %[[CURRENT_SIZE_1]], %[[SAME_SIZE]] : index
|
||||
// CHECK-NEXT: %[[FINAL_SAME_SIZE:.*]] = select %[[IS_DIFFERENT_SIZE]], %[[CURRENT_SIZE_1]], %[[SAME_SIZE]] : index
|
||||
// CHECK: %[[FINAL_DIFFERENT_SIZES:.*]] = or %[[DIFFERENT_SIZES]], %[[IS_DIFFERENT_SIZE:.*]] : i1
|
||||
// CHECK: %[[FINAL_JUST_OUT_OF_BOUNDS:.*]] = or %[[JUST_OUT_OF_BOUNDS]], %[[IS_ONE_OUT_OF_BOUNDS:.*]] : i1
|
||||
// CHECK-NEXT: %[[STOP_COMBINING_DIMENSIONS:.*]] = or %[[FINAL_DIFFERENT_SIZES]], %[[FINAL_JUST_OUT_OF_BOUNDS]] : i1
|
||||
// CHECK-NEXT: %[[IF_STOP_COMBINING_DIMENSIONS:.*]]:2 = scf.if %[[STOP_COMBINING_DIMENSIONS]] -> (index, index) {
|
||||
// CHECK-NEXT: %[[IS_RUNNING_PRODUCT_NOT_ONE:.*]] = cmpi ne, %[[RUNNING_PRODUCT]], %[[C1_1]] : index
|
||||
// CHECK-NEXT: %[[NEW_OFFSET_1:.*]] = scf.if %[[IS_RUNNING_PRODUCT_NOT_ONE]] -> (index) {
|
||||
// CHECK-NEXT: %[[NEW_OFFSET_0:.*]] = addi %[[OFFSET]], %[[C1_1]] : index
|
||||
// CHECK-NEXT: %[[WAS_IN_BOUNDS:.*]] = cmpi sge, %[[RESULT_DIMENSION]], %[[MINUS_ONE]] : index
|
||||
// CHECK-NEXT: scf.if %[[WAS_IN_BOUNDS]] {
|
||||
// CHECK-NEXT: %[[CURRENT_DIMENSION:.*]] = subi %[[REDUCED_RANK_LHS]], %[[NEW_OFFSET_0]] : index
|
||||
// CHECK-NEXT: store %[[RUNNING_PRODUCT]], %[[RESULT_LHS]][%[[CURRENT_DIMENSION]]] : memref<?xindex>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK: scf.yield %[[NEW_OFFSET_0]] : index
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: scf.yield %[[OFFSET]] : index
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %[[IF_DIFFERENT_SIZES:.*]]:2 = scf.if %[[FINAL_DIFFERENT_SIZES]] -> (index, index) {
|
||||
// CHECK-NEXT: %[[NEW_OFFSET_2:.*]] = addi %[[NEW_OFFSET_1]], %[[C1_1]] : index
|
||||
// CHECK-NEXT: %[[IS_IN_BOUNDS:.*]] = cmpi sge, %[[RESULT_DIMENSION]], %[[C0]] : index
|
||||
// CHECK-NEXT: scf.if %[[IS_IN_BOUNDS]] {
|
||||
// CHECK-NEXT: %[[CURRENT_DIMENSION:.*]] = subi %[[REDUCED_RANK_LHS]], %[[NEW_OFFSET_2]] : index
|
||||
// CHECK-NEXT: store %[[CURRENT_SIZE]], %[[RESULT_LHS]][%[[CURRENT_DIMENSION]]] : memref<?xindex>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK: scf.yield %[[C1_1]], %[[NEW_OFFSET_2]] : index, index
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: scf.yield %[[FINAL_SAME_SIZE]], %[[NEW_OFFSET_1]] : index, index
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[IF_DIFFERENT_SIZES]]#0, %[[IF_DIFFERENT_SIZES]]#1 : index, index
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[NEW_RUNNING_PRODUCT:.*]] = muli %[[RUNNING_PRODUCT]], %[[FINAL_SAME_SIZE]] : index
|
||||
// CHECK-NEXT: scf.yield %[[NEW_RUNNING_PRODUCT]], %[[OFFSET]] : index, index
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[IF_STOP_COMBINING_DIMENSIONS]]#0, %[[IF_STOP_COMBINING_DIMENSIONS]]#1 : index, index
|
||||
// CHECK-NEXT: }
|
||||
// CHECK: return %[[SUBVIEW_LHS:.*]], %[[SUBVIEW_RHS:.*]] : memref<?xindex>, memref<?xindex>
|
||||
%0, %1 = chlo.minimum_broadcast_shapes %lhs, %rhs :
|
||||
tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
|
||||
return %0, %1 : tensor<?xindex>, tensor<?xindex>
|
||||
}
|
||||
|
||||
@ -37,6 +37,7 @@ cc_library(
|
||||
hdrs = ["rewriters.h"],
|
||||
compatible_with = get_compatible_with_cloud(),
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
|
||||
@ -22,7 +22,9 @@ limitations under the License.
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
|
||||
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -86,6 +88,289 @@ class BufferizeDimOp : public OpConversionPattern<DimOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Bufferizes and lowers chlo.minimum_broadcast_shapes in one step: the
// tensor<?xindex> shape operands become 1-D index memrefs, and the op's
// semantics (collapse adjacent broadcast-compatible dimensions, strip
// leading 1's) are expanded into explicit scf.for / scf.if control flow.
class BufferizeAndConvertMinimumBroadcastShapesOp
    : public OpConversionPattern<chlo::MinimumBroadcastShapesOp> {
 public:
  using OpConversionPattern<
      chlo::MinimumBroadcastShapesOp>::OpConversionPattern;

  // Rewrites one minimum_broadcast_shapes op. `operands` are the already
  // bufferized shape operands (1-D memrefs of index — each is read with
  // LoadOp/DimOp below). Returns one reduced shape buffer per operand.
  LogicalResult matchAndRewrite(
      chlo::MinimumBroadcastShapesOp broadcast_shapes_op,
      ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    chlo::MinimumBroadcastShapesOp::Adaptor adaptor(operands);
    auto loc = broadcast_shapes_op.getLoc();
    ImplicitLocOpBuilder lb(loc, rewriter);
    Value zero = lb.create<ConstantIndexOp>(0);
    SmallVector<Value> shapes = adaptor.shapes();
    // k = number of shape operands; all per-operand vectors below are
    // indexed in lockstep with `shapes`.
    size_t k = shapes.size();
    SmallVector<Value> ranks;
    ranks.reserve(k);
    SmallVector<Value> real_ranks;
    real_ranks.reserve(k);
    SmallVector<Value> leading_ones;
    leading_ones.reserve(k);

    // Determine the "real" rank of each operand shape by counting leading 1's.
    for (size_t i = 0; i < k; ++i) {
      // Rank of shape i == extent of its 1-D shape buffer.
      Value rank = lb.create<DimOp>(loc, shapes[i], zero);
      ranks.push_back(rank);
      leading_ones.push_back(CountLeadingOnes(lb, shapes[i], rank));
      Value real_rank = lb.create<SubIOp>(rank, leading_ones[i]);
      real_ranks.push_back(real_rank);
    }

    // Determine the maximum real rank of the operands.
    Value max_rank = real_ranks[0];
    for (size_t i = 1; i < k; ++i) {
      Value rank_is_greater =
          lb.create<CmpIOp>(CmpIPredicate::ugt, real_ranks[i], max_rank);
      max_rank = lb.create<SelectOp>(rank_is_greater, real_ranks[i], max_rank);
    }

    // Allocate buffers for the return values and initialize them with 1's.
    SmallVector<Value> result_shapes;
    result_shapes.reserve(k);
    auto result_type =
        MemRefType::get({ShapedType::kDynamicSize}, lb.getIndexType());
    Value one = lb.create<ConstantIndexOp>(1);
    for (size_t i = 0; i < k; ++i) {
      // We assume the buffer will be small, so we allocate it on the stack.
      // TODO(b/181654096): Replace AllocaOp with AllocOp.
      auto result = lb.create<AllocaOp>(result_type, real_ranks[i]);
      // Fill the result buffer with 1's so untouched trailing entries are
      // valid extents.
      lb.create<scf::ForOp>(zero, real_ranks[i], one, llvm::None,
                            [&one, &result](OpBuilder &b, Location l, Value idx,
                                            ValueRange /*vr*/) {
                              b.create<StoreOp>(l, one, result, idx);
                              b.create<scf::YieldOp>(l, llvm::None);
                            });
      result_shapes.push_back(result);
    }

    // Iterate through the dimensions and determine which adjacent dimensions
    // can be combined. Keep a running product of the dimensions that can be
    // combined as iteration variable (initialized to 1), and the current
    // dimension offset in the result shapes. We iterate through the shapes
    // backward, because the broadcasting semantics mean that the last
    // dimensions of each shape (the least significant ones) are matched
    // together.
    Value running_product = one;
    Value current_dimension_offset = zero;
    Value two = lb.create<ConstantIndexOp>(2);
    Value max_rank_plus_two = lb.create<AddIOp>(loc, max_rank, two);

    // Iterate from 1 to max_rank + 1 (inclusive). This iteration variable is
    // used as an offset from the end of each shape vector. We iterate until
    // max_rank + 1 to handle the case that we have a running_product > 1 left
    // when we have processed all dimensions of the largest shape.
    lb.create<scf::ForOp>(
        one, max_rank_plus_two, one,
        ValueRange{running_product, current_dimension_offset},
        [&](OpBuilder &b, Location l, Value v, ValueRange vr) {
          Value constant_false =
              b.create<ConstantOp>(l, b.getI1Type(), b.getBoolAttr(false));
          Value just_out_of_bounds = constant_false;
          Value different_sizes = constant_false;
          // -1 doubles as "uninitialized size" sentinel and as the
          // result-dimension value meaning "just went out of bounds".
          Value minus_one = b.create<ConstantIndexOp>(l, -1);

          // Initialize 'same_size' to a size that we don't expect to see.
          Value same_size = minus_one;
          // 'result_dimensions' stores the current dimension with an offset of
          // 'leading_ones' to make it easier to check whether we are in-bounds
          // with respect to the "real" shape with leading 1's removed.
          SmallVector<Value> result_dimensions;
          SmallVector<Value> sizes;
          result_dimensions.reserve(k);
          sizes.reserve(k);

          // This loop checks whether we have at least two shapes with different
          // sizes at the current dimension, and whether we just ran out of
          // bounds in at least one shape.
          for (size_t i = 0; i < k; ++i) {
            // Determine the size of the dimension. If the dimension is out of
            // bounds, we choose the value 'same_size', because then the shape
            // should not affect the check anymore whether there are two shapes
            // with different sizes at the current dimension.
            Value is_out_of_bounds =
                b.create<CmpIOp>(l, CmpIPredicate::ult, real_ranks[i], v);
            // v counts from the back, so this indexes the v-th dimension
            // from the end of shape i.
            Value dimension = b.create<SubIOp>(l, ranks[i], v);
            Value result_dimension =
                b.create<SubIOp>(l, dimension, leading_ones[i]);
            result_dimensions.push_back(result_dimension);
            Value current_size =
                b.create<scf::IfOp>(
                     l, TypeRange{b.getIndexType()}, is_out_of_bounds,
                     [&](OpBuilder &b, Location l) {
                       b.create<scf::YieldOp>(l, same_size);
                     },
                     [&](OpBuilder &b, Location l) {
                       // Using IfOp instead of SelectOp makes sure that we
                       // don't try to load if the dimension is out of bounds.
                       Value size = b.create<LoadOp>(l, shapes[i], dimension);
                       b.create<scf::YieldOp>(l, size);
                     })
                    .getResult(0);
            sizes.push_back(current_size);
            // Fold current_size into same_size: adopt it if same_size is
            // still the sentinel, otherwise record whether it differs.
            Value is_initialized =
                b.create<CmpIOp>(l, CmpIPredicate::ne, same_size, minus_one);
            same_size =
                b.create<SelectOp>(l, is_initialized, same_size, current_size);
            Value is_different_size =
                b.create<CmpIOp>(l, CmpIPredicate::ne, current_size, same_size);
            same_size = b.create<SelectOp>(l, is_different_size, current_size,
                                           same_size);
            different_sizes =
                b.create<OrOp>(l, different_sizes, is_different_size);
            // result_dimension == -1 means shape i ran out of bounds exactly
            // at this step.
            Value is_one_out_of_bounds = b.create<CmpIOp>(
                l, CmpIPredicate::eq, result_dimension, minus_one);
            just_out_of_bounds =
                b.create<OrOp>(l, just_out_of_bounds, is_one_out_of_bounds);
          }
          // Loop-carried state: (running_product, current_dimension_offset).
          Value running_product = vr.front();
          Value current_dimension_offset = vr.back();

          // We need to stop combining dimensions if we just ran out of bounds
          // in one shape, or there are at least two shapes with different sizes
          // at the current dimension.
          Value stop_combining_dimensions =
              b.create<OrOp>(l, different_sizes, just_out_of_bounds);
          auto if_stop_combining_dimensions = b.create<scf::IfOp>(
              l, TypeRange{b.getIndexType(), b.getIndexType()},
              stop_combining_dimensions,
              [&](OpBuilder &b, Location l) {
                // If the running product is not 1, add one dimension of size
                // 'running_product' to each shape that is still indexed
                // in-bounds or has just gone out of bounds.
                Value running_product_not_one = b.create<CmpIOp>(
                    l, CmpIPredicate::ne, running_product, one);
                Value new_dimension_offset =
                    b.create<scf::IfOp>(
                         l, TypeRange{b.getIndexType()},
                         running_product_not_one,
                         [&](OpBuilder &b, Location l) {
                           Value new_dimension_offset = b.create<AddIOp>(
                               l, current_dimension_offset, one);
                           for (size_t i = 0; i < k; ++i) {
                             // >= -1 admits shapes that just went out of
                             // bounds, matching the comment above.
                             Value was_in_bounds = b.create<CmpIOp>(
                                 l, CmpIPredicate::sge, result_dimensions[i],
                                 minus_one);
                             b.create<scf::IfOp>(
                                 l, was_in_bounds,
                                 [&](OpBuilder &b, Location l) {
                                   Value output_dimension = b.create<SubIOp>(
                                       l, real_ranks[i], new_dimension_offset);
                                   b.create<StoreOp>(l, running_product,
                                                     result_shapes[i],
                                                     output_dimension);
                                   b.create<scf::YieldOp>(l, llvm::None);
                                 });
                           }
                           b.create<scf::YieldOp>(l, new_dimension_offset);
                         },
                         [&](OpBuilder &b, Location l) {
                           b.create<scf::YieldOp>(l, current_dimension_offset);
                         })
                        .getResult(0);

                // If there are at least two different sizes, copy the dimension
                // size from the input to the output shapes for all shapes that
                // are still indexed in-bounds.
                auto if_different_sizes = b.create<scf::IfOp>(
                    l, TypeRange{b.getIndexType(), b.getIndexType()},
                    different_sizes,
                    [&](OpBuilder &b, Location l) {
                      Value dimension_offset =
                          b.create<AddIOp>(l, new_dimension_offset, one);
                      for (size_t i = 0; i < k; ++i) {
                        Value is_in_bounds = b.create<CmpIOp>(
                            l, CmpIPredicate::sge, result_dimensions[i], zero);
                        b.create<scf::IfOp>(
                            l, is_in_bounds, [&](OpBuilder &b, Location l) {
                              Value output_dimension = b.create<SubIOp>(
                                  l, real_ranks[i], dimension_offset);
                              b.create<StoreOp>(l, sizes[i], result_shapes[i],
                                                output_dimension);
                              b.create<scf::YieldOp>(l, llvm::None);
                            });
                      }
                      // Reset the running product to 1 for the next group.
                      b.create<scf::YieldOp>(l,
                                             ValueRange{one, dimension_offset});
                    },
                    [&](OpBuilder &b, Location l) {
                      // Sizes all agree: start a new running product from the
                      // shared size.
                      b.create<scf::YieldOp>(
                          l, ValueRange{same_size, new_dimension_offset});
                    });
                b.create<scf::YieldOp>(l, if_different_sizes.getResults());
              },
              [&](OpBuilder &b, Location l) {
                // Dimensions are still combinable: keep accumulating the
                // product without emitting anything yet.
                Value new_running_product =
                    b.create<MulIOp>(l, running_product, same_size);
                b.create<scf::YieldOp>(l, ValueRange{new_running_product,
                                                     current_dimension_offset});
              });
          b.create<scf::YieldOp>(l, if_stop_combining_dimensions.getResults());
        });
    // The combined results may themselves start with 1's (buffers were
    // pre-filled with 1's); strip those before replacing the op.
    for (size_t i = 0; i < k; ++i) {
      result_shapes[i] =
          RemoveLeadingOnesFrom1DMemref(lb, result_shapes[i], real_ranks[i]);
    }
    rewriter.replaceOp(broadcast_shapes_op, result_shapes);
    return success();
  }

 private:
  // Returns the number of leading extents equal to 1 in the 1-D index
  // memref 'extent_memref' of length 'rank'.
  Value CountLeadingOnes(ImplicitLocOpBuilder &lb, Value extent_memref,
                         Value rank) const {
    // Count leading 1's. Use two iteration variables for that: one with a
    // boolean flag for whether every size so far was 1, one with the number of
    // leading 1's.
    Value constant_true =
        lb.create<ConstantOp>(lb.getI1Type(), lb.getBoolAttr(true));
    Value zero = lb.create<ConstantIndexOp>(0);
    Value one = lb.create<ConstantIndexOp>(1);
    auto leading_ones_loop = lb.create<scf::ForOp>(
        zero, rank, one, ValueRange{constant_true, zero},
        [&](OpBuilder &b, Location l, Value idx, ValueRange vr) {
          auto size = b.create<LoadOp>(l, extent_memref, idx);
          auto is_equal_to_one =
              b.create<CmpIOp>(l, CmpIPredicate::eq, size, one);
          // Once a non-1 extent is seen, all_ones stays false and the count
          // stops increasing.
          auto all_ones = b.create<AndOp>(l, vr.front(), is_equal_to_one);
          auto increased_value = b.create<AddIOp>(l, vr.back(), one);
          auto number_of_leading_ones =
              b.create<SelectOp>(l, all_ones, increased_value, vr.back());
          b.create<scf::YieldOp>(l,
                                 ValueRange{all_ones, number_of_leading_ones});
        });
    // results()[0] is the all-ones flag; [1] is the leading-ones count.
    return leading_ones_loop.results()[1];
  }

  // Returns a new stack buffer containing 'extent_memref' with its leading
  // 1-extents removed.
  Value RemoveLeadingOnesFrom1DMemref(ImplicitLocOpBuilder &lb,
                                      Value extent_memref, Value rank) const {
    Value leading_ones = CountLeadingOnes(lb, extent_memref, rank);
    Value new_rank = lb.create<SubIOp>(rank, leading_ones);
    auto result_type =
        MemRefType::get({ShapedType::kDynamicSize}, lb.getIndexType());
    // Ideally we would use SubView here to return a MemRef with 'leading_ones'
    // as offset, but several things related to MemRef with offsets are
    // currently broken, so instead we just allocate another buffer of the
    // desired size and copy the elements over. We assume the buffer will be
    // small, so we allocate it on the stack.
    // TODO(b/181654096): Replace AllocaOp with AllocOp.
    Value result = lb.create<AllocaOp>(result_type, new_rank);
    Value zero = lb.create<ConstantIndexOp>(0);
    Value one = lb.create<ConstantIndexOp>(1);
    lb.create<scf::ForOp>(
        zero, new_rank, one, llvm::None,
        [&](OpBuilder &b, Location l, Value idx, ValueRange /*vr*/) {
          Value idx_with_offset = b.create<AddIOp>(l, idx, leading_ones);
          auto size = b.create<LoadOp>(l, extent_memref, idx_with_offset);
          b.create<StoreOp>(l, size, result, idx);
          b.create<scf::YieldOp>(l, llvm::None);
        });
    return result;
  }
};
|
||||
|
||||
class BufferizeRankOp : public OpConversionPattern<RankOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
@ -102,8 +387,10 @@ class BufferizeRankOp : public OpConversionPattern<RankOp> {
|
||||
// Registers the extra std-bufferization patterns defined in this file.
// Note: the rendered diff left both the pre-change and post-change
// `insert` calls in place, which registered BufferizeConstantOp,
// BufferizeDimOp and BufferizeRankOp twice. Keep only the updated call,
// which also adds BufferizeAndConvertMinimumBroadcastShapesOp.
void populateExtraStdBufferizePattern(MLIRContext *context,
                                      BufferizeTypeConverter *converter,
                                      OwningRewritePatternList *patterns) {
  patterns
      ->insert<BufferizeConstantOp, BufferizeDimOp,
               BufferizeAndConvertMinimumBroadcastShapesOp, BufferizeRankOp>(
          *converter, context);
}
|
||||
|
||||
} // namespace transforms
|
||||
|
||||
@ -42,6 +42,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||
#include "mlir/Transforms/Bufferize.h" // from @llvm-project
|
||||
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
@ -172,7 +173,8 @@ struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> {
|
||||
|
||||
target.addIllegalDialect<mhlo::MhloDialect>();
|
||||
target.addIllegalOp<tensor::GenerateOp, tensor::ExtractOp,
|
||||
tensor::FromElementsOp, tensor::CastOp, TensorLoadOp,
|
||||
tensor::FromElementsOp, tensor::CastOp,
|
||||
chlo::MinimumBroadcastShapesOp, TensorLoadOp,
|
||||
TensorToMemrefOp>();
|
||||
BufferizeTypeConverter converter;
|
||||
auto typesAreLegal = [&converter](Operation* op) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user