diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 83773d882a6..579e89ca137 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -1201,6 +1201,8 @@ def HLO_PadOp: HLO_Op<"pad",
 
   // TODO(b/129422361): PadOp has a custom constructor for HLO.
   let hasCustomHLOConverter = 1;
+
+  let hasFolder = 1;
 }
 
 def HLO_TraceOp: HLO_Op<"trace", []>, BASE_HLO_TraceOp {
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index c1e3ceb2435..c04e27d50ed 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -1787,6 +1787,61 @@ static LogicalResult Verify(PadOp op) {
   return success();
 }
 
+OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
+  // If all padding is zero then it is an identity pad.
+  auto is_zero = [](const APInt& i) { return i == 0; };
+  if (llvm::all_of(edge_padding_low().getIntValues(), is_zero) &&
+      llvm::all_of(edge_padding_high().getIntValues(), is_zero) &&
+      llvm::all_of(interior_padding().getIntValues(), is_zero))
+    return operand();
+
+  // If any padding is negative then it isn't supported by the folder (yet).
+  auto is_negative = [](const APInt& i) { return i.slt(0); };
+  if (llvm::any_of(edge_padding_low().getIntValues(), is_negative) ||
+      llvm::any_of(edge_padding_high().getIntValues(), is_negative) ||
+      llvm::any_of(interior_padding().getIntValues(), is_negative))
+    return {};
+
+  DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  DenseElementsAttr padding = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+  RankedTensorType return_type = getType().dyn_cast_or_null<RankedTensorType>();
+  if (!input || !input.getType().hasRank() || !padding || !return_type ||
+      !return_type.hasStaticShape())
+    return {};
+
+  // Fill the full result tensor with the padding value.
+  llvm::SmallVector<Attribute, 4> result(return_type.getNumElements(),
+                                         padding.getValue<Attribute>({}));
+
+  auto next_index = [](llvm::SmallVector<uint64_t, 8>& index,
+                       llvm::ArrayRef<int64_t> shape) {
+    for (int64_t i = index.size() - 1; i >= 0; --i) {
+      ++index[i];
+      if (index[i] < shape[i]) return true;
+      index[i] = 0;
+    }
+    return false;
+  };
+
+  // Iterate over all elements of the input tensor and copy it to the correct
+  // location in the output tensor.
+  llvm::SmallVector<uint64_t, 8> index(input.getType().getRank(), 0);
+  do {
+    uint64_t linear_index = 0;
+    uint64_t linear_index_multiplier = 1;
+    for (int64_t i = index.size() - 1; i >= 0; --i) {
+      linear_index +=
+          (edge_padding_low().getValue<int64_t>({uint64_t(i)}) +
+           index[i] *
+               (interior_padding().getValue<int64_t>({uint64_t(i)}) + 1)) *
+          linear_index_multiplier;
+      linear_index_multiplier *= return_type.getShape()[i];
+    }
+    result[linear_index] = input.getValue<Attribute>(index);
+  } while (next_index(index, input.getType().getShape()));
+  return DenseElementsAttr::get(return_type, result);
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir
index 718a436a5ac..7624ba929ea 100644
--- a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir
@@ -1437,3 +1437,29 @@ func @scatter_out_of_bound() -> tensor<3x3xi32> {
   // CHECK: "mhlo.scatter"
 }
 
+// CHECK-LABEL: @pad_identity_fold
+func @pad_identity_fold(%arg0: tensor<5x7xf32>) -> tensor<5x7xf32> {
+  %0 = constant dense<0.0> : tensor<f32>
+  %1 = "mhlo.pad"(%arg0, %0) {
+    edge_padding_low = dense<0> : tensor<2xi64>,
+    edge_padding_high = dense<0> : tensor<2xi64>,
+    interior_padding = dense<0> : tensor<2xi64>
+  } : (tensor<5x7xf32>, tensor<f32>) -> tensor<5x7xf32>
+  return %1 : tensor<5x7xf32>
+  // CHECK: return %arg0 : tensor<5x7xf32>
+}
+
+// CHECK-LABEL: @pad_fold
+func @pad_fold() -> tensor<4x5xi32> {
+  %0 = constant dense<[[2, 3], [4, 5]]> : tensor<2x2xi32>
+  %1 = constant dense<1> : tensor<i32>
+  %3 = "mhlo.pad"(%0, %1) {
+    edge_padding_low = dense<[1, 0]> : tensor<2xi64>,
+    edge_padding_high = dense<[1, 2]> : tensor<2xi64>,
+    interior_padding = dense<[0, 1]> : tensor<2xi64>
+  } : (tensor<2x2xi32>, tensor<i32>) -> tensor<4x5xi32>
+  return %3 : tensor<4x5xi32>
+  // CHECK: constant dense<[
+  // CHECK-SAME: [1, 1, 1, 1, 1], [2, 1, 3, 1, 1], [4, 1, 5, 1, 1], [1, 1, 1, 1, 1]
+  // CHECK-SAME: ]> : tensor<4x5xi32>
+}