From 7dad9545e54f54917c7d31441cda58f94294d9ee Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Thu, 17 Dec 2020 09:25:26 -0800 Subject: [PATCH] Handle operands with zero elements in HLO PadOp folder PiperOrigin-RevId: 348034821 Change-Id: Ie1cce424cd6387e95354e1f5e8e52ea9360bcac2 --- .../mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc | 21 ++++++++++--------- .../compiler/mlir/hlo/tests/canonicalize.mlir | 8 +++++++ 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc index 24049eee9b5..3a091a94ea0 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -1879,28 +1879,29 @@ OpFoldResult PadOp::fold(ArrayRef operands) { llvm::ArrayRef shape) { for (int64_t i = index.size() - 1; i >= 0; --i) { ++index[i]; - if (index[i] < shape[i]) return true; + if (index[i] < shape[i]) return; index[i] = 0; } - return false; }; // Iterate over all elements of the input tensor and copy it to the correct // location in the output tensor. 
llvm::SmallVector index(input.getType().getRank(), 0); - do { - uint64_t linear_index = 0; - uint64_t linear_index_multiplyer = 1; + uint64_t num_elements = input.getNumElements(); + for (uint64_t operand_idx = 0; operand_idx < num_elements; operand_idx++) { + uint64_t result_idx = 0; + uint64_t idx_multiplyer = 1; for (int64_t i = index.size() - 1; i >= 0; --i) { - linear_index += + result_idx += (edge_padding_low().getValue({uint64_t(i)}) + index[i] * (interior_padding().getValue({uint64_t(i)}) + 1)) * - linear_index_multiplyer; - linear_index_multiplyer *= return_type.getShape()[i]; + idx_multiplyer; + idx_multiplyer *= return_type.getDimSize(i); } - result[linear_index] = input.getValue(index); - } while (next_index(index, input.getType().getShape())); + result[result_idx] = input.getValue(index); + next_index(index, input.getType().getShape()); + } return DenseElementsAttr::get(return_type, result); } diff --git a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir index e50f7f3d327..8e17895c9ad 100644 --- a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir @@ -1515,6 +1515,14 @@ func @pad_fold() -> tensor<4x5xi32> { // CHECK-SAME: ]> : tensor<4x5xi32> } +func @pad_fold_zero_elements() -> tensor<3xi32> { + %0 = mhlo.constant dense<> : tensor<0xi32> + %1 = mhlo.constant dense<7> : tensor<i32> + %2 = "mhlo.pad"(%0, %1) {edge_padding_high = dense<3> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<0xi32>, tensor<i32>) -> tensor<3xi32> + return %2 : tensor<3xi32> + // CHECK: mhlo.constant dense<7> : tensor<3xi32> +} + // CHECK-LABEL: @identity_broadcast_reshape func @identity_broadcast_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> { %0 = "mhlo.broadcast"(%arg0) {