[XLA][MLIR] Insert DynamicMemRefCastOp when lowering HLO DynamicBroadcastInDim.

In order to support expansion of size-1 dims in the dynamic case, we insert a
DynamicMemRefCastOp that casts the argument to a "type-erased" memref, setting
the stride to 0 and the size to the target dimension's size wherever expansion
is necessary.

Example:

```mlir
%outDimX = constant 3 : index
%outDimY = constant 4 : index
%outShape = "xla_hlo.scalars_to_dimension_tensor"(%outDimX, %outDimY)
     : (index, index) -> tensor<2xindex>
%out = "xla_hlo.dynamic_broadcast_in_dim"(%src, %outShape){
  broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>

// *** IR Dump After HloLegalizeToLhlo ***

// Constants
%outDimX = constant 3 : index
%outDimY = constant 4 : index
%c0 = constant 0 : index
%c1 = constant 1 : index
%true = constant 1 : i1

// Compute if expansion is needed for dim X.
%srcDimX = dim %src, 0 : memref<?x?xf32>
%isExpansionX = cmpi "eq", %srcDimX, %c1 : index
// Set stride to 0 if expansion is needed.
%strideX = select %isExpansionX, %c0, %c1 : index

// Compute if expansion is needed for dim Y.
%srcDimY = dim %src, 1 : memref<?x?xf32>
%isExpansionY = cmpi "eq", %srcDimY, %c1 : index
// Set stride to 0 if expansion is needed.
%strideY = select %isExpansionY, %c0, %c1 : index

#map = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
%src_transformed = xla_lhlo.dynamic_memref_cast
    %src(%outDimX, %outDimY)[%strideX, %strideY]
    : memref<?x?xf32> -> memref<?x?xf32, #map>
"xla_lhlo.broadcast_in_dim"(%src_transformed, %out) {
  broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
} : (memref<?x?xf32, #map>, memref<?x?xf32>) -> ()

```
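The cast works because a 0 stride turns expansion into pure address arithmetic:
the affine map above computes `d0 * s0 + d1 * s1`, so with `s0 = 0` every row
reads the same underlying elements, while the non-expanding case keeps stride 1
and degenerates to ordinary contiguous indexing. A minimal standalone C++
sketch of that arithmetic (the helper name and buffer are illustrative, not
part of this change):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>: a stride of 0
// collapses that dimension, so the size-1 source data is read repeatedly.
float LoadStrided(const std::vector<float>& buf, int64_t d0, int64_t d1,
                  int64_t s0, int64_t s1) {
  return buf[d0 * s0 + d1 * s1];
}

int main() {
  // A 1x4 source "expanded" to 3x4: stride 0 along dim 0, stride 1 along dim 1.
  std::vector<float> src = {10, 20, 30, 40};
  for (int64_t i = 0; i < 3; ++i)
    for (int64_t j = 0; j < 4; ++j)
      assert(LoadStrided(src, i, j, /*s0=*/0, /*s1=*/1) == src[j]);
  return 0;
}
```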

PiperOrigin-RevId: 314337158
Change-Id: I2d20dab9e96583f14032f585c6d2752b0dfb7f08
Alexander Belyaev 2020-06-02 08:41:31 -07:00 committed by TensorFlower Gardener
parent cf744bfeb1
commit fa75abd1ad
2 changed files with 102 additions and 7 deletions


@@ -1,4 +1,4 @@
// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement %s -o - | FileCheck %s --dump-input-on-failure
// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: func @attrs
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -157,13 +157,16 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
func @external_func() -> tensor<3xi64>
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
// CHECK-LABEL: func @dyn_broadcast
func @dyn_broadcast(%operand: memref<?x?xf32>) {
// CHECK-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
%tensor_operand = tensor_load %operand : memref<?x?xf32>
%shape = call @external_func() : () -> tensor<3xi64>
%tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape)
{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
: (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
%tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
// CHECK: %[[SHAPE:.*]] = call @external_func()
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
@@ -175,7 +178,33 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
// CHECK: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
// CHECK-NEXT: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %[[RESULT]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[C1__:.*]] = constant 1 : index
// CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
// CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], 0 : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
// CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
// CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
// CHECK: %[[C2_:.*]] = constant 2 : index
// CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
// CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], 1 : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = xla_lhlo.dynamic_memref_cast
// CHECK-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
// CHECK-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
// CHECK-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map0>
// CHECK: "xla_lhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
// CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
// CHECK-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
// Do not store the value back to avoid the tensor-store being rewritten to
// a copy into the pre-allocated argument.
return


@@ -17,6 +17,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/AffineMap.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
@@ -153,13 +154,78 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
auto loc = op.getLoc();
Value resultBuffer = InsertDynamicAllocAndDealloc(
loc, op.getResult(), op.output_dimensions(), &rewriter);
rewriter.create<xla_lhlo::BroadcastInDimOp>(loc, operands[0], resultBuffer,
op.broadcast_dimensions());
Value transformed_operand =
InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
rewriter.create<xla_lhlo::BroadcastInDimOp>(
loc, transformed_operand, resultBuffer, op.broadcast_dimensions());
rewriter.replaceOp(op, {resultBuffer});
return success();
}
private:
// Inserts a DynamicMemRefCastOp that changes the layout of the operand
// memref: wherever size-1 dimension expansion is necessary, the stride is
// set to 0 and the size to the target dimension's size.
xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
xla_hlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
auto loc = op.getLoc();
auto operand_type = operand.getType().cast<MemRefType>();
auto operand_shape = operand_type.getShape();
SmallVector<Value, 2> sizes, strides;
sizes.reserve(operand_shape.size());
strides.reserve(operand_shape.size());
Value zero = b->create<ConstantIndexOp>(loc, 0);
Value one = b->create<ConstantIndexOp>(loc, 1);
for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
Value broadcast_dim_value =
b->create<ConstantIndexOp>(loc, dim.value().getSExtValue());
Value result_dim_size = b->create<ExtractElementOp>(
loc, op.output_dimensions(), broadcast_dim_value);
Value operand_dim_size =
ShapedType::isDynamic(operand_shape[dim.index()])
? b->create<DimOp>(loc, operand, dim.index()).getResult()
: b->create<ConstantIndexOp>(loc, operand_shape[dim.index()])
.getResult();
// TODO(pifon): Revisit if this cast is needed. Maybe we can use
// tensor<index> for `output_dimensions` as well.
if (!result_dim_size.getType().isIndex()) {
result_dim_size =
b->create<IndexCastOp>(loc, result_dim_size, b->getIndexType());
}
// There can be two cases:
// 1) Operand dim == result dim => expansion is not needed => stride := 1.
// 2) Operand dim < result dim => expansion is needed => stride := 0.
Value is_expansion = b->create<CmpIOp>(loc, CmpIPredicate::slt,
operand_dim_size, result_dim_size);
strides.push_back(
b->create<mlir::SelectOp>(loc, is_expansion, zero, one));
// Size of input dim can be set to the size of the corresponding output
// dimension for both cases.
sizes.push_back(result_dim_size);
}
// Type-erased memref type with static rank, dynamic sizes and strides.
SmallVector<int64_t, 2> dynamic_layout(operand_shape.size(),
MemRefType::kDynamicStrideOrOffset);
SmallVector<int64_t, 2> dynamic_shape(operand_shape.size(),
MemRefType::kDynamicSize);
auto type_erased_memref_type = MemRefType::get(
dynamic_shape, operand_type.getElementType(),
makeStridedLinearLayoutMap(dynamic_layout,
/*offset=*/0, b->getContext()));
auto transformed_operand = b->create<xla_lhlo::DynamicMemRefCastOp>(
loc, type_erased_memref_type, operand, sizes, strides);
return transformed_operand;
}
};
struct HloToLhloReduceOpConverter : public BaseOpConversion<xla_hlo::ReduceOp> {
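
For context, a conversion pattern like this is only exercised once it is added
to the pass's pattern list. The populate function below is a hedged sketch:
its name is hypothetical and the exact registration site is not shown in the
hunk, though `OwningRewritePatternList` and `insert<Pattern>(context)` match
MLIR conventions of this period.

```cpp
// Sketch only: the real populate function in this file also registers the
// other HLO-to-LHLO converters alongside this one.
void populateDynamicBroadcastPattern(MLIRContext* context,
                                     OwningRewritePatternList* patterns) {
  patterns->insert<HloToLhloDynamicBroadcastInDimOpConverter>(context);
}
```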