diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
index f4b9fa206f2..38ea818aea8 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
@@ -1,4 +1,4 @@
-// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement %s -o - | FileCheck %s --dump-input-on-failure
+// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FileCheck %s --dump-input-on-failure
 
 // CHECK-LABEL: func @attrs
 func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -157,13 +157,16 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
 func @external_func() -> tensor<3xi64>
 
+// CHECK: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
+
 // CHECK-LABEL: func @dyn_broadcast
 func @dyn_broadcast(%operand: memref<?x?xf32>) {
+  // CHECK-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
   %tensor_operand = tensor_load %operand : memref<?x?xf32>
   %shape = call @external_func() : () -> tensor<3xi64>
-  %tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape)
-      {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
-      : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
+  %tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
+    broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
+  } : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
   // CHECK: %[[SHAPE:.*]] = call @external_func()
   // CHECK: %[[C0:.*]] = constant 0 : index
   // CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
@@ -175,7 +178,33 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
   // CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
   // CHECK: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
   // CHECK: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
-  // CHECK-NEXT: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %[[RESULT]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
+
+  // CHECK: %[[C0_:.*]] = constant 0 : index
+  // CHECK: %[[C1_:.*]] = constant 1 : index
+
+  // CHECK: %[[C1__:.*]] = constant 1 : index
+  // CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
+  // CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], 0 : memref<?x?xf32>
+  // CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
+  // CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
+  // CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
+
+  // CHECK: %[[C2_:.*]] = constant 2 : index
+  // CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
+  // CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], 1 : memref<?x?xf32>
+  // CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
+  // CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
+  // CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
+
+  // CHECK: %[[TRANSFORMED_MEMREF:.*]] = xla_lhlo.dynamic_memref_cast
+  // CHECK-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
+  // CHECK-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
+  // CHECK-SAME: : memref<?x?xf32> -> memref<?x?xf32, #[[MAP]]>
+
+  // CHECK: "xla_lhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
+  // CHECK-SAME:   broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
+  // CHECK-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
+
   // Do not store the value back to avoid the tensor-store being rewritten to
   // a copy into the pre-allocated argument.
   return
diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
index df92681cd97..6f5bafef4c0 100644
--- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "absl/memory/memory.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/AffineMap.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
@@ -153,13 +154,78 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
     auto loc = op.getLoc();
     Value resultBuffer = InsertDynamicAllocAndDealloc(
         loc, op.getResult(), op.output_dimensions(), &rewriter);
-    rewriter.create<xla_lhlo::BroadcastInDimOp>(loc, operands[0], resultBuffer,
-                                                op.broadcast_dimensions());
+
+    Value transformed_operand =
+        InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
+    rewriter.create<xla_lhlo::BroadcastInDimOp>(
+        loc, transformed_operand, resultBuffer, op.broadcast_dimensions());
 
     rewriter.replaceOp(op, {resultBuffer});
 
     return success();
   }
+
+ private:
+  // Inserts a xla_lhlo.dynamic_memref_cast op that changes the layout of the
+  // operand memref: a dimension that requires size-1 expansion gets a 0 stride
+  // and the size of the corresponding target dimension.
+  xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
+      xla_hlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
+    auto loc = op.getLoc();
+    auto operand_type = operand.getType().cast<MemRefType>();
+    auto operand_shape = operand_type.getShape();
+
+    SmallVector<Value, 2> sizes, strides;
+    sizes.reserve(operand_shape.size());
+    strides.reserve(operand_shape.size());
+
+    Value zero = b->create<ConstantIndexOp>(loc, 0);
+    Value one = b->create<ConstantIndexOp>(loc, 1);
+    for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
+      Value broadcast_dim_value =
+          b->create<ConstantIndexOp>(loc, dim.value().getSExtValue());
+      Value result_dim_size = b->create<ExtractElementOp>(
+          loc, op.output_dimensions(), broadcast_dim_value);
+      Value operand_dim_size =
+          ShapedType::isDynamic(operand_shape[dim.index()])
+              ? b->create<DimOp>(loc, operand, dim.index()).getResult()
+              : b->create<ConstantIndexOp>(loc, operand_shape[dim.index()])
+                    .getResult();
+
+      // TODO(pifon): Revisit if this cast is needed. Maybe we can use
+      // tensor<index> for `output_dimensions` as well.
+      if (!result_dim_size.getType().isIndex()) {
+        result_dim_size =
+            b->create<IndexCastOp>(loc, result_dim_size, b->getIndexType());
+      }
+
+      // There can be two cases:
+      // 1) Operand dim == result dim => expansion is not needed => stride := 1.
+      // 2) Operand dim < result dim => expansion is needed => stride := 0.
+      Value is_expansion = b->create<CmpIOp>(loc, CmpIPredicate::slt,
+                                             operand_dim_size, result_dim_size);
+      strides.push_back(
+          b->create<SelectOp>(loc, is_expansion, zero, one));
+
+      // The size of the input dim can be set to the size of the corresponding
+      // output dimension in both cases.
+      sizes.push_back(result_dim_size);
+    }
+
+    // Type-erased memref type with static rank, dynamic sizes and strides.
+    SmallVector<int64_t, 2> dynamic_layout(operand_shape.size(),
+                                           MemRefType::kDynamicStrideOrOffset);
+    SmallVector<int64_t, 2> dynamic_shape(operand_shape.size(),
+                                          MemRefType::kDynamicSize);
+    auto type_erased_memref_type = MemRefType::get(
+        dynamic_shape, operand_type.getElementType(),
+        makeStridedLinearLayoutMap(dynamic_layout,
+                                   /*offset=*/0, b->getContext()));
+
+    auto transformed_operand = b->create<xla_lhlo::DynamicMemRefCastOp>(
+        loc, type_erased_memref_type, operand, sizes, strides);
+    return transformed_operand;
+  }
 };
 
 struct HloToLhloReduceOpConverter : public BaseOpConversion<xla_hlo::ReduceOp> {
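
For reference, a rough sketch of the IR the new lowering path is expected to produce for the @dyn_broadcast test, assembled from the CHECK lines above; SSA names, the #map alias, and the surrounding values (%operand, %result, %out1, %out2 for the requested output sizes of result dims 1 and 2) are illustrative, not verbatim pass output:

  #map = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>

  // Per operand dimension: stride 0 if the dimension is expanded
  // (operand size < requested result size), stride 1 otherwise.
  %zero = constant 0 : index
  %one = constant 1 : index
  %dim0 = dim %operand, 0 : memref<?x?xf32>
  %expand0 = cmpi "slt", %dim0, %out1 : index
  %stride0 = select %expand0, %zero, %one : index
  %dim1 = dim %operand, 1 : memref<?x?xf32>
  %expand1 = cmpi "slt", %dim1, %out2 : index
  %stride1 = select %expand1, %zero, %one : index
  // Reinterpret the operand with the computed sizes/strides, then run the
  // ordinary lhlo broadcast into the pre-allocated result buffer.
  %cast = xla_lhlo.dynamic_memref_cast %operand(%out1, %out2)[%stride0, %stride1]
      : memref<?x?xf32> -> memref<?x?xf32, #map>
  "xla_lhlo.broadcast_in_dim"(%cast, %result)
      {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
      : (memref<?x?xf32, #map>, memref<?x?x?xf32>) -> ()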