[XLA][MLIR] Insert DynamicMemRefCastOp when lowering HLO DynamicBroadcastInDim.
In order to support expansion of size-1 dims in the dynamic case, we insert a DynamicMemRefCastOp that casts the argument to a "type-erased" memref, using a 0-stride and the size of the target dimension wherever expansion is necessary.

Example:

```mlir
%outDimX = constant 3 : index
%outDimY = constant 4 : index
%outShape = "xla_hlo.scalars_to_dimension_tensor"(%outDimX, %outDimY)
    : (index, index) -> tensor<2xindex>
%out = "xla_hlo.dynamic_broadcast_in_dim"(%src, %outShape) {
  broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>

// *** IR Dump After HloLegalizeToLhlo ***

// Constants.
%outDimX = constant 3 : index
%outDimY = constant 4 : index
%c0 = constant 0 : index
%c1 = constant 1 : index
%true = constant 1 : i1

// Compute whether expansion is needed for dim X.
%srcDimX = dim %src, 0 : memref<?x?xf32>
%isExpansionX = cmpi "eq", %srcDimX, %c1 : index
// Set the stride to 0 if expansion is needed.
%strideX = select %isExpansionX, %c0, %c1 : index

// Compute whether expansion is needed for dim Y.
%srcDimY = dim %src, 1 : memref<?x?xf32>
%isExpansionY = cmpi "eq", %srcDimY, %c1 : index
// Set the stride to 0 if expansion is needed.
%strideY = select %isExpansionY, %c0, %c1 : index

// Cast the source to the "type-erased" memref with dynamic sizes and strides.
#map = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
%src_transformed = xla_lhlo.dynamic_memref_cast
    %src(%outDimX, %outDimY)[%strideX, %strideY]
    : memref<?x?xf32> -> memref<?x?xf32, #map>

"xla_lhlo.broadcast_in_dim"(%src_transformed, %out) {
  broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
} : (memref<?x?xf32, #map>, memref<?x?xf32>) -> ()
```

PiperOrigin-RevId: 314337158
Change-Id: I2d20dab9e96583f14032f585c6d2752b0dfb7f08
parent: cf744bfeb1
commit: fa75abd1ad
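For intuition, here is a minimal sketch of why a 0-stride implements the expansion; the values below (`%rows`, `%cols`, a single-row `%src`) are hypothetical and not part of this change. With `s0 = 0`, the layout map sends every index `(i, j)` to offset `i * 0 + j * 1 = j`, so all rows of the casted memref alias the one source row and no data is copied.

```mlir
#map = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>

// Assume %src : memref<?x?xf32> holds a single row of 4 elements.
%c0 = constant 0 : index   // stride 0 along the expanded dimension
%c1 = constant 1 : index
%rows = constant 3 : index
%cols = constant 4 : index
%view = xla_lhlo.dynamic_memref_cast %src(%rows, %cols)[%c0, %c1]
    : memref<?x?xf32> -> memref<?x?xf32, #map>
// A load of %view[i, j] computes offset i * 0 + j * 1 = j for any i.
```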
@@ -1,4 +1,4 @@
-// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement %s -o - | FileCheck %s --dump-input-on-failure
+// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FileCheck %s --dump-input-on-failure

 // CHECK-LABEL: func @attrs
 func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -157,13 +157,16 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {

 func @external_func() -> tensor<3xi64>

+// CHECK: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
+
 // CHECK-LABEL: func @dyn_broadcast
 func @dyn_broadcast(%operand: memref<?x?xf32>) {
+  // CHECK-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
   %tensor_operand = tensor_load %operand : memref<?x?xf32>
   %shape = call @external_func() : () -> tensor<3xi64>
-  %tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape)
-      {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
-        : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
+  %tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
+    broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
+  } : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
   // CHECK: %[[SHAPE:.*]] = call @external_func()
   // CHECK: %[[C0:.*]] = constant 0 : index
   // CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
@@ -175,7 +178,33 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
   // CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
   // CHECK: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
   // CHECK: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
-  // CHECK-NEXT: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %[[RESULT]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
+
+  // CHECK: %[[C0_:.*]] = constant 0 : index
+  // CHECK: %[[C1_:.*]] = constant 1 : index
+
+  // CHECK: %[[C1__:.*]] = constant 1 : index
+  // CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
+  // CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], 0 : memref<?x?xf32>
+  // CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
+  // CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
+  // CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
+
+  // CHECK: %[[C2_:.*]] = constant 2 : index
+  // CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
+  // CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], 1 : memref<?x?xf32>
+  // CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
+  // CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
+  // CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
+
+  // CHECK: %[[TRANSFORMED_MEMREF:.*]] = xla_lhlo.dynamic_memref_cast
+  // CHECK-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
+  // CHECK-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
+  // CHECK-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map0>
+
+  // CHECK: "xla_lhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
+  // CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
+  // CHECK-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
+
   // Do not store the value back to avoid the tensor-store being rewritten to
   // a copy into the pre-allocated argument.
   return
@@ -17,6 +17,7 @@ limitations under the License.

 #include "absl/memory/memory.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/AffineMap.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
@@ -153,13 +154,78 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
     auto loc = op.getLoc();
     Value resultBuffer = InsertDynamicAllocAndDealloc(
         loc, op.getResult(), op.output_dimensions(), &rewriter);
-    rewriter.create<xla_lhlo::BroadcastInDimOp>(loc, operands[0], resultBuffer,
-                                                op.broadcast_dimensions());
+
+    Value transformed_operand =
+        InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
+    rewriter.create<xla_lhlo::BroadcastInDimOp>(
+        loc, transformed_operand, resultBuffer, op.broadcast_dimensions());

     rewriter.replaceOp(op, {resultBuffer});

     return success();
   }
+
+ private:
+  // Inserts a DynamicMemRefCastOp that changes the layout of the operand
+  // memref: wherever size-1 dimension expansion is necessary, the stride is
+  // set to 0 and the size to that of the corresponding target dimension.
+  xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
+      xla_hlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
+    auto loc = op.getLoc();
+    auto operand_type = operand.getType().cast<MemRefType>();
+    auto operand_shape = operand_type.getShape();
+
+    SmallVector<Value, 2> sizes, strides;
+    sizes.reserve(operand_shape.size());
+    strides.reserve(operand_shape.size());
+
+    Value zero = b->create<ConstantIndexOp>(loc, 0);
+    Value one = b->create<ConstantIndexOp>(loc, 1);
+    for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
+      Value broadcast_dim_value =
+          b->create<ConstantIndexOp>(loc, dim.value().getSExtValue());
+      Value result_dim_size = b->create<ExtractElementOp>(
+          loc, op.output_dimensions(), broadcast_dim_value);
+      Value operand_dim_size =
+          ShapedType::isDynamic(operand_shape[dim.index()])
+              ? b->create<DimOp>(loc, operand, dim.index()).getResult()
+              : b->create<ConstantIndexOp>(loc, operand_shape[dim.index()])
+                    .getResult();
+
+      // TODO(pifon): Revisit if this cast is needed. Maybe we can use
+      // tensor<index> for `output_dimensions` as well.
+      if (!result_dim_size.getType().isIndex()) {
+        result_dim_size =
+            b->create<IndexCastOp>(loc, result_dim_size, b->getIndexType());
+      }
+
+      // There can be two cases:
+      // 1) Operand dim == result dim => expansion is not needed => stride := 1.
+      // 2) Operand dim < result dim => expansion is needed => stride := 0.
+      Value is_expansion = b->create<CmpIOp>(loc, CmpIPredicate::slt,
+                                             operand_dim_size, result_dim_size);
+      strides.push_back(
+          b->create<mlir::SelectOp>(loc, is_expansion, zero, one));
+
+      // The size of the input dim can be set to the size of the corresponding
+      // output dimension in both cases.
+      sizes.push_back(result_dim_size);
+    }
+
+    // Type-erased memref type with static rank, dynamic sizes and strides.
+    SmallVector<int64_t, 2> dynamic_layout(operand_shape.size(),
+                                           MemRefType::kDynamicStrideOrOffset);
+    SmallVector<int64_t, 2> dynamic_shape(operand_shape.size(),
+                                          MemRefType::kDynamicSize);
+    auto type_erased_memref_type = MemRefType::get(
+        dynamic_shape, operand_type.getElementType(),
+        makeStridedLinearLayoutMap(dynamic_layout,
+                                   /*offset=*/0, b->getContext()));
+
+    auto transformed_operand = b->create<xla_lhlo::DynamicMemRefCastOp>(
+        loc, type_erased_memref_type, operand, sizes, strides);
+    return transformed_operand;
+  }
 };

 struct HloToLhloReduceOpConverter : public BaseOpConversion<xla_hlo::ReduceOp> {