diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index 7bf7780d20c..9213469e2d3 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -96,6 +96,44 @@ Attribute ConstFoldBinaryOpScalarScalar(Type result_type, Attribute operand1,
                            calculate(lhs.getValue(), rhs.getValue()));
 }
 
+// Returns a new shape with rank 'new_dims', padded with ones on the left if
+// needed.
+inline std::vector<int64_t> GetPaddedShape(ArrayRef<int64_t> old_shape,
+                                           int new_dims) {
+  std::vector<int64_t> new_shape(new_dims, 1);
+  std::copy_backward(old_shape.begin(), old_shape.end(), new_shape.end());
+  return new_shape;
+}
+
+// Helper method that, given a 'current_index' into the broadcasted tensor,
+// returns the corresponding index into the flat original tensor. 'shape' is
+// the original shape padded to match the result shape.
+int64_t GetElementIndex(const std::vector<int64_t> &shape,
+                        const std::vector<int64_t> &current_index) {
+  int64_t ind = 0;
+  int64_t mul = 1;
+  for (int i = shape.size() - 1; i >= 0; --i) {
+    ind += (current_index[i] % shape[i]) * mul;
+    mul *= shape[i];
+  }
+  return ind;
+}
+
+// Helper method that increments the index held in 'current_index_ptr' within
+// the shape 'result_shape'.
+void IncrementIndex(ArrayRef<int64_t> result_shape,
+                    std::vector<int64_t> *current_index_ptr) {
+  std::vector<int64_t> &current_index = *current_index_ptr;
+  for (int i = result_shape.size() - 1; i >= 0; --i) {
+    current_index[i]++;
+    if (current_index[i] == result_shape[i]) {
+      current_index[i] = 0;
+    } else {
+      break;
+    }
+  }
+}
+
 /// Performs const folding `calculate` with broadcast behavior on the two
 /// attributes `operand1` and `operand2` and returns the result if possible.
 /// This function assumes the both operands are verified to have value
@@ -107,23 +145,10 @@ template <class AttrElementT,
-  auto type = result_type.cast<ShapedType>();
-
-  if (lhs.getType() != rhs.getType()) {
-    // We only support the case that one of the operand's dimensions are
-    // a perfect suffix of the other.
-    // TODO: support the general broadcast behavior.
-    auto lhs_shape = lhs.getType().getShape();
-    auto rhs_shape = rhs.getType().getShape();
-    if (IsTrailingDimensions(lhs_shape, rhs_shape)) {
-      if (!type.hasStaticShape()) type = rhs.getType();
-    } else if (IsTrailingDimensions(rhs_shape, lhs_shape)) {
-      if (!type.hasStaticShape()) type = lhs.getType();
-    } else {
-      return {};
-    }
-  } else if (!type.hasStaticShape()) {
-    type = lhs.getType();
+  auto type = OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType())
+                  .dyn_cast_or_null<ShapedType>();
+  if (!type) {
+    return {};
   }
 
   const bool rhs_is_splat = rhs.isSplat();
@@ -139,14 +164,7 @@ Attribute ConstFoldBinaryOpDenseDense(Type result_type, DenseElementsAttr lhs,
     return DenseElementsAttr::get(type, element_result);
   }
 
-  auto lhs_num_elements = lhs.getType().getNumElements();
-  auto rhs_num_elements = rhs.getType().getNumElements();
-  auto num_elements = std::max(lhs_num_elements, rhs_num_elements);
-
-  // We assume the arguments have broadcast-compatible types. Make sure again.
-  assert(std::max(lhs_num_elements, rhs_num_elements) == num_elements);
-  assert(num_elements % std::min(lhs_num_elements, rhs_num_elements) == 0);
-
+  auto num_elements = type.getNumElements();
 
   SmallVector<ElementValueT, 16> lhs_old_values;
   SmallVector<ElementValueT, 16> rhs_old_values;
   if (lhs_is_splat)
     lhs_old_values.push_back(lhs.getSplatValue<ElementValueT>());
   else
     lhs_old_values = llvm::to_vector<16>(lhs.getValues<ElementValueT>());
@@ -157,31 +175,32 @@ Attribute ConstFoldBinaryOpDenseDense(Type result_type, DenseElementsAttr lhs,
     rhs_old_values.push_back(rhs.getSplatValue<ElementValueT>());
   else
     rhs_old_values = llvm::to_vector<16>(rhs.getValues<ElementValueT>());
-
   SmallVector<ElementValueT, 16> new_values;
   new_values.reserve(num_elements);
+  const auto result_shape = type.getShape();
+  std::vector<int64_t> current_index(type.getRank(), 0);
+  // Create the new shape with ones padded to the left.
+  std::vector<int64_t> lhs_new_shape =
+      GetPaddedShape(lhs.getType().getShape(), type.getRank());
+  std::vector<int64_t> rhs_new_shape =
+      GetPaddedShape(rhs.getType().getShape(), type.getRank());
 
   // Add each pair of the corresponding values in the dense elements
   // attributes.
-  for (int i = 0; i < num_elements; ++i) {
-    // We only support a degenerated case here: the dimensions in one operand's
-    // shape is a perfect suffix to the other operand. Then conceptually it's
-    // similar to broadcasting a scalar to a 1-D vector.
-    // TODO: support the general broadcast behavior.
-    // We are tiling the operand with less elements an integral times to match
-    // the operand with more elements. We don't care which operand has less
-    // elements here because we are iterating its elements in circles, which can
-    // be achieved using the result index modulo the element count. For the
-    // operand with more elements, since the result has the same number of
-    // elements, we are only going over its elements once. The modulo operation
-    // also works for that.
-    int lhs_index = lhs_is_splat ? 0 : (i % lhs_num_elements);
-    int rhs_index = rhs_is_splat ? 0 : (i % rhs_num_elements);
+  for (int64_t i = 0; i < num_elements; ++i) {
+    // 'current_index' represents the index in the N-dimensional result
+    // tensor. GetElementIndex maps it to the corresponding index in the
+    // flat representation of each original operand, which is the index
+    // used below.
+    int64_t lhs_index =
+        lhs_is_splat ? 0 : GetElementIndex(lhs_new_shape, current_index);
+    int64_t rhs_index =
+        rhs_is_splat ? 0 : GetElementIndex(rhs_new_shape, current_index);
 
     new_values.push_back(
         calculate(lhs_old_values[lhs_index], rhs_old_values[rhs_index]));
+    IncrementIndex(result_shape, &current_index);
   }
-
   return DenseElementsAttr::get(type, new_values);
 }
 
@@ -308,7 +327,7 @@ void buildFusedBroadcastableBinOp(Builder *builder, OperationState &result,
 //===----------------------------------------------------------------------===//
 
 OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
-  // Skip fused ops for now.
+  // TODO(b/142478136): Handle fused ops.
   if (fused_activation_function() != "NONE") return {};
   return ConstFoldBinaryOp(
       getType(), operands, [](APFloat a, APFloat b) { return a + b; },
@@ -636,13 +655,26 @@ static void BuildGatherOp(Builder *builder, OperationState &result,
 //===----------------------------------------------------------------------===//
 
 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
-  // Skip fused ops for now.
+  // TODO(b/142478136): Handle fused ops.
   if (fused_activation_function() != "NONE") return {};
   return ConstFoldBinaryOp(
       getType(), operands, [](APFloat a, APFloat b) { return a * b; },
       [](APInt a, APInt b) { return a * b; }, getOperation()->isCommutative());
 }
 
+//===----------------------------------------------------------------------===//
+// DivOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
+  // TODO(b/142478136): Handle fused ops.
+  if (fused_activation_function() != "NONE") return {};
+  return ConstFoldBinaryOp(
+      getType(), operands, [](APFloat a, APFloat b) { return a / b; },
+      [](APInt a, APInt b) { return a.sdiv(b); },
+      getOperation()->isCommutative());
+}
+
 //===----------------------------------------------------------------------===//
 // PackOp
 //===----------------------------------------------------------------------===//
@@ -922,7 +954,7 @@ static LogicalResult Verify(SliceOp op) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
-  // Skip fused ops for now.
+  // TODO(b/142478136): Handle fused ops.
   if (fused_activation_function() != "NONE") return {};
   return ConstFoldBinaryOp(
       getType(), operands, [](APFloat a, APFloat b) { return a - b; },
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index 8c80fd0dfb0..0598365284f 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -908,6 +908,8 @@ def TFL_DivOp : TFL_Op<"div", [Broadcastable, NoSideEffect]> {
   let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
 
   let hasOptions = 1;
+
+  let hasFolder = 1;
 }
 
 def TFL_EluOp: TFL_Op<"elu", [NoSideEffect, SameOperandsAndResultType]> {
diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
index 5853df8e467..99b2c86a0c6 100644
--- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
@@ -226,11 +226,8 @@ func @add_dense_dense_int_mixing_1_n() -> tensor<2x2xi32> {
   %0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
 
   return %0 : tensor<2x2xi32>
-
-// We don't support this case yet.
-// %cst = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32>
-// CHECK: %0 = "tfl.add"
-// CHECK: return %0
+// CHECK: %cst = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32>
+// CHECK: return %cst
 }
 
 // CHECK-LABEL: @add_dense_splat_float
@@ -299,9 +296,8 @@ func @add_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
 
   return %0 : tensor<2x2xf32>
 
-// We don't support this case yet.
-// CHECK: %0 = "tfl.add"
-// CHECK: return %0
+// CHECK: %cst = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32>
+// CHECK: return %cst
 }
 
 // CHECK-LABEL: @rank
@@ -555,3 +551,29 @@ func @concatConstantTensorsLastDim() -> tensor<1x2x6xi32> {
 // CHECK-NOT: "tfl.concatenation"
 // CHECK: return [[cst]]
 }
+
+// CHECK-LABEL: @div_dense_dense_float_mixfng_1_n
+func @div_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
+  %cst_0 = constant dense<[[1.5, -2.5]]> : tensor<1x2xf32>
+  %cst_1 = constant dense<[[-3.], [4.]]> : tensor<2x1xf32>
+
+  %0 = "tfl.div"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
+
+  return %0 : tensor<2x2xf32>
+
+// CHECK: %cst = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32>
+// CHECK: return %cst
+}
+
+// CHECK-LABEL: @div_dense_different_rank
+func @div_dense_different_rank() -> tensor<1x2x2xf32> {
+  %cst_0 = constant dense<[[[1.0],[2.0]]]> : tensor<1x2x1xf32>
+  %cst_1 = constant dense<[[2.0, 3.0]]> : tensor<1x2xf32>
+
+  %0 = "tfl.div"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2x1xf32>, tensor<1x2xf32>) -> tensor<1x2x2xf32>
+
+  return %0 : tensor<1x2x2xf32>
+
+// CHECK: %cst = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32>
+// CHECK: return %cst
+}
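
Note (reviewer sketch, not part of the patch): the standalone program below is a minimal illustration of how the new GetPaddedShape / GetElementIndex / IncrementIndex helpers cooperate to fold a broadcasted binary op. It re-implements the helpers without the MLIR/LLVM dependencies (plain std::vector stands in for ArrayRef, and the element type is hard-coded to float), and replays the div_dense_different_rank test case above; the printed values match that test's CHECK line.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Left-pad 'old_shape' with ones until it has 'new_dims' dimensions.
std::vector<int64_t> GetPaddedShape(const std::vector<int64_t> &old_shape,
                                    int new_dims) {
  std::vector<int64_t> new_shape(new_dims, 1);
  std::copy_backward(old_shape.begin(), old_shape.end(), new_shape.end());
  return new_shape;
}

// Map an index in the broadcasted result to the flat index of the operand
// whose padded shape is 'shape'; the modulo collapses size-1 dimensions.
int64_t GetElementIndex(const std::vector<int64_t> &shape,
                        const std::vector<int64_t> &current_index) {
  int64_t ind = 0;
  int64_t mul = 1;
  for (int i = shape.size() - 1; i >= 0; --i) {
    ind += (current_index[i] % shape[i]) * mul;
    mul *= shape[i];
  }
  return ind;
}

// Advance 'current_index' to the next position of 'result_shape', row-major.
void IncrementIndex(const std::vector<int64_t> &result_shape,
                    std::vector<int64_t> *current_index_ptr) {
  std::vector<int64_t> &current_index = *current_index_ptr;
  for (int i = result_shape.size() - 1; i >= 0; --i) {
    current_index[i]++;
    if (current_index[i] == result_shape[i]) {
      current_index[i] = 0;
    } else {
      break;
    }
  }
}

int main() {
  // div_dense_different_rank: tensor<1x2x1xf32> [[[1], [2]]] divided by
  // tensor<1x2xf32> [[2, 3]], broadcast to tensor<1x2x2xf32>.
  const std::vector<int64_t> result_shape = {1, 2, 2};
  const std::vector<float> lhs_values = {1.0f, 2.0f};  // flat 1x2x1 operand
  const std::vector<float> rhs_values = {2.0f, 3.0f};  // flat 1x2 operand
  const std::vector<int64_t> lhs_shape = GetPaddedShape({1, 2, 1}, 3);  // 1x2x1
  const std::vector<int64_t> rhs_shape = GetPaddedShape({1, 2}, 3);     // 1x1x2

  std::vector<int64_t> current_index(result_shape.size(), 0);
  for (int64_t i = 0; i < 4; ++i) {  // 4 == number of result elements
    float result = lhs_values[GetElementIndex(lhs_shape, current_index)] /
                   rhs_values[GetElementIndex(rhs_shape, current_index)];
    std::printf("%g ", result);  // prints: 0.5 0.333333 1 0.666667
    IncrementIndex(result_shape, &current_index);
  }
  std::printf("\n");
}

The key point is that each operand's shape is left-padded with ones up to the result rank, so 'current_index[i] % shape[i]' collapses to 0 on broadcast (size-1) dimensions and the flat index simply tiles the smaller operand across the result.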