- Add constant folding for TFL::DivOp
- Support broadcasting for matching rank with values = 1. PiperOrigin-RevId: 274069438
This commit is contained in:
parent
af35d3b9d4
commit
3bd66c35b0
@ -96,6 +96,44 @@ Attribute ConstFoldBinaryOpScalarScalar(Type result_type, Attribute operand1,
|
||||
calculate(lhs.getValue(), rhs.getValue()));
|
||||
}
|
||||
|
||||
// Returns new shape with rank 'new_dims' with padded ones on the
|
||||
// left if needed.
|
||||
inline std::vector<int64_t> GetPaddedShape(ArrayRef<int64_t> old_shape,
|
||||
int new_dims) {
|
||||
std::vector<int64_t> new_shape(new_dims, 1);
|
||||
std::copy_backward(old_shape.begin(), old_shape.end(), new_shape.end());
|
||||
return new_shape;
|
||||
}
|
||||
|
||||
// Helper method that given and 'current_index' representing
|
||||
// index in broadcasted tensor, get the index in the flat original tensor.
|
||||
// 'shape' is the original shape with padding to match result shape.
|
||||
int64_t GetElementIndex(const std::vector<int64_t> &shape,
|
||||
const std::vector<int64_t> ¤t_index) {
|
||||
int64_t ind = 0;
|
||||
int64_t mul = 1;
|
||||
for (int i = shape.size() - 1; i >= 0; --i) {
|
||||
ind += (current_index[i] % shape[i]) * mul;
|
||||
mul *= shape[i];
|
||||
}
|
||||
return ind;
|
||||
}
|
||||
|
||||
// Helper method that increment index represented in 'current_index_ptr'
|
||||
// in the shape of 'result_shape'.
|
||||
void IncrementIndex(ArrayRef<int64_t> result_shape,
|
||||
std::vector<int64_t> *current_index_ptr) {
|
||||
std::vector<int64_t> ¤t_index = *current_index_ptr;
|
||||
for (int i = result_shape.size() - 1; i >= 0; --i) {
|
||||
current_index[i]++;
|
||||
if (current_index[i] == result_shape[i]) {
|
||||
current_index[i] = 0;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Performs const folding `calculate` with broadcast behavior on the two
|
||||
/// attributes `operand1` and `operand2` and returns the result if possible.
|
||||
/// This function assumes the both operands are verified to have value
|
||||
@ -107,23 +145,10 @@ template <class AttrElementT,
|
||||
Attribute ConstFoldBinaryOpDenseDense(Type result_type, DenseElementsAttr lhs,
|
||||
DenseElementsAttr rhs,
|
||||
const CalculationT &calculate) {
|
||||
auto type = result_type.cast<ShapedType>();
|
||||
|
||||
if (lhs.getType() != rhs.getType()) {
|
||||
// We only support the case that one of the operand's dimensions are
|
||||
// a perfect suffix of the other.
|
||||
// TODO: support the general broadcast behavior.
|
||||
auto lhs_shape = lhs.getType().getShape();
|
||||
auto rhs_shape = rhs.getType().getShape();
|
||||
if (IsTrailingDimensions(lhs_shape, rhs_shape)) {
|
||||
if (!type.hasStaticShape()) type = rhs.getType();
|
||||
} else if (IsTrailingDimensions(rhs_shape, lhs_shape)) {
|
||||
if (!type.hasStaticShape()) type = lhs.getType();
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
} else if (!type.hasStaticShape()) {
|
||||
type = lhs.getType();
|
||||
auto type = OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType())
|
||||
.dyn_cast_or_null<ShapedType>();
|
||||
if (!type) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const bool rhs_is_splat = rhs.isSplat();
|
||||
@ -139,14 +164,7 @@ Attribute ConstFoldBinaryOpDenseDense(Type result_type, DenseElementsAttr lhs,
|
||||
return DenseElementsAttr::get(type, element_result);
|
||||
}
|
||||
|
||||
auto lhs_num_elements = lhs.getType().getNumElements();
|
||||
auto rhs_num_elements = rhs.getType().getNumElements();
|
||||
auto num_elements = std::max(lhs_num_elements, rhs_num_elements);
|
||||
|
||||
// We assume the arguments have broadcast-compatible types. Make sure again.
|
||||
assert(std::max(lhs_num_elements, rhs_num_elements) == num_elements);
|
||||
assert(num_elements % std::min(lhs_num_elements, rhs_num_elements) == 0);
|
||||
|
||||
auto num_elements = type.getNumElements();
|
||||
SmallVector<ElementValueT, 16> lhs_old_values;
|
||||
SmallVector<ElementValueT, 16> rhs_old_values;
|
||||
if (lhs_is_splat)
|
||||
@ -157,31 +175,32 @@ Attribute ConstFoldBinaryOpDenseDense(Type result_type, DenseElementsAttr lhs,
|
||||
rhs_old_values.push_back(rhs.getSplatValue<ElementValueT>());
|
||||
else
|
||||
rhs_old_values = llvm::to_vector<16>(rhs.getValues<ElementValueT>());
|
||||
|
||||
SmallVector<ElementValueT, 16> new_values;
|
||||
new_values.reserve(num_elements);
|
||||
const auto result_shape = type.getShape();
|
||||
std::vector<int64_t> current_index(type.getRank(), 0);
|
||||
// Create the new shape with ones padded to the left.
|
||||
std::vector<int64_t> lhs_new_shape =
|
||||
GetPaddedShape(lhs.getType().getShape(), type.getRank());
|
||||
std::vector<int64_t> rhs_new_shape =
|
||||
GetPaddedShape(rhs.getType().getShape(), type.getRank());
|
||||
|
||||
// Add each pair of the corresponding values in the dense elements
|
||||
// attributes.
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
// We only support a degenerated case here: the dimensions in one operand's
|
||||
// shape is a perfect suffix to the other operand. Then conceptually it's
|
||||
// similar to broadcasting a scalar to a 1-D vector.
|
||||
// TODO: support the general broadcast behavior.
|
||||
// We are tiling the operand with less elements an integral times to match
|
||||
// the operand with more elements. We don't care which operand has less
|
||||
// elements here because we are iterating its elements in circles, which can
|
||||
// be achieved using the result index modulo the element count. For the
|
||||
// operand with more elements, since the result has the same number of
|
||||
// elements, we are only going over its elements once. The modulo operation
|
||||
// also works for that.
|
||||
int lhs_index = lhs_is_splat ? 0 : (i % lhs_num_elements);
|
||||
int rhs_index = rhs_is_splat ? 0 : (i % rhs_num_elements);
|
||||
for (int64_t i = 0; i < num_elements; ++i) {
|
||||
// current_index represents the index
|
||||
// in the N-dimension tensor. GetElementIndex returns
|
||||
// the index in the flat representation of the original tensor
|
||||
// to use.
|
||||
int64_t lhs_index =
|
||||
lhs_is_splat ? 0 : GetElementIndex(lhs_new_shape, current_index);
|
||||
int64_t rhs_index =
|
||||
rhs_is_splat ? 0 : GetElementIndex(rhs_new_shape, current_index);
|
||||
|
||||
new_values.push_back(
|
||||
calculate(lhs_old_values[lhs_index], rhs_old_values[rhs_index]));
|
||||
IncrementIndex(result_shape, ¤t_index);
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(type, new_values);
|
||||
}
|
||||
|
||||
@ -308,7 +327,7 @@ void buildFusedBroadcastableBinOp(Builder *builder, OperationState &result,
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Skip fused ops for now.
|
||||
// TODO(b/142478136): Handle fused ops.
|
||||
if (fused_activation_function() != "NONE") return {};
|
||||
return ConstFoldBinaryOp(
|
||||
getType(), operands, [](APFloat a, APFloat b) { return a + b; },
|
||||
@ -636,13 +655,26 @@ static void BuildGatherOp(Builder *builder, OperationState &result,
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Skip fused ops for now.
|
||||
// TODO(b/142478136): Handle fused ops.
|
||||
if (fused_activation_function() != "NONE") return {};
|
||||
return ConstFoldBinaryOp(
|
||||
getType(), operands, [](APFloat a, APFloat b) { return a * b; },
|
||||
[](APInt a, APInt b) { return a * b; }, getOperation()->isCommutative());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DivOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
|
||||
// TODO(b/142478136): Handle fused ops.
|
||||
if (fused_activation_function() != "NONE") return {};
|
||||
return ConstFoldBinaryOp(
|
||||
getType(), operands, [](APFloat a, APFloat b) { return a / b; },
|
||||
[](APInt a, APInt b) { return a.sdiv(b); },
|
||||
getOperation()->isCommutative());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PackOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -922,7 +954,7 @@ static LogicalResult Verify(SliceOp op) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Skip fused ops for now.
|
||||
// TODO(b/142478136): Handle fused ops.
|
||||
if (fused_activation_function() != "NONE") return {};
|
||||
return ConstFoldBinaryOp(
|
||||
getType(), operands, [](APFloat a, APFloat b) { return a - b; },
|
||||
|
@ -908,6 +908,8 @@ def TFL_DivOp : TFL_Op<"div", [Broadcastable, NoSideEffect]> {
|
||||
let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
|
||||
|
||||
let hasOptions = 1;
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TFL_EluOp: TFL_Op<"elu", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
|
@ -226,11 +226,8 @@ func @add_dense_dense_int_mixing_1_n() -> tensor<2x2xi32> {
|
||||
%0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
|
||||
|
||||
return %0 : tensor<2x2xi32>
|
||||
|
||||
// We don't support this case yet.
|
||||
// %cst = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32>
|
||||
// CHECK: %0 = "tfl.add"
|
||||
// CHECK: return %0
|
||||
// CHECK: %cst = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32>
|
||||
// CHECK: return %cst
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_splat_float
|
||||
@ -299,9 +296,8 @@ func @add_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
|
||||
|
||||
return %0 : tensor<2x2xf32>
|
||||
|
||||
// We don't support this case yet.
|
||||
// CHECK: %0 = "tfl.add"
|
||||
// CHECK: return %0
|
||||
// CHECK: %cst = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32>
|
||||
// CHECK: return %cst
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @rank
|
||||
@ -555,3 +551,29 @@ func @concatConstantTensorsLastDim() -> tensor<1x2x6xi32> {
|
||||
// CHECK-NOT: "tfl.concatenation"
|
||||
// CHECK: return [[cst]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @div_dense_dense_float_mixfng_1_n
|
||||
func @div_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
|
||||
%cst_0 = constant dense<[[1.5, -2.5]]> : tensor<1x2xf32>
|
||||
%cst_1 = constant dense<[[-3.], [4.]]> : tensor<2x1xf32>
|
||||
|
||||
%0 = "tfl.div"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
|
||||
|
||||
return %0 : tensor<2x2xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32>
|
||||
// CHECK: return %cst
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @div_dense_different_rank
|
||||
func @div_dense_different_rank() -> tensor<1x2x2xf32> {
|
||||
%cst_0 = constant dense<[[[1.0],[2.0]]]> : tensor<1x2x1xf32>
|
||||
%cst_1 = constant dense<[[2.0, 3.0]]> : tensor<1x2xf32>
|
||||
|
||||
%0 = "tfl.div"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2x1xf32>, tensor<1x2xf32>) -> tensor<1x2x2xf32>
|
||||
|
||||
return %0 : tensor<1x2x2xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32>
|
||||
// CHECK: return %cst
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user