- Add constant folding for TFL::DivOp

- Support broadcasting between operands of different ranks by left-padding the lower-rank shape with dimensions of size 1.

PiperOrigin-RevId: 274069438
This commit is contained in:
Karim Nosir 2019-10-10 16:59:34 -07:00 committed by TensorFlower Gardener
parent af35d3b9d4
commit 3bd66c35b0
3 changed files with 108 additions and 52 deletions

View File

@ -96,6 +96,44 @@ Attribute ConstFoldBinaryOpScalarScalar(Type result_type, Attribute operand1,
calculate(lhs.getValue(), rhs.getValue()));
}
// Pads `old_shape` on the left with 1s so that the returned shape has
// rank `new_dims`. If the rank already matches, the shape is copied
// through unchanged.
inline std::vector<int64_t> GetPaddedShape(ArrayRef<int64_t> old_shape,
                                           int new_dims) {
  std::vector<int64_t> padded(new_dims, 1);
  // Copy the original dimensions into the rightmost slots; the leading
  // entries keep their size-1 padding.
  const int offset = new_dims - static_cast<int>(old_shape.size());
  for (int i = 0, e = static_cast<int>(old_shape.size()); i < e; ++i) {
    padded[offset + i] = old_shape[i];
  }
  return padded;
}
// Maps `current_index` -- a position in the broadcasted result tensor --
// to the corresponding flat (linearized) index in the original tensor.
// `shape` is the original operand shape already padded with 1s to match
// the result rank; size-1 dimensions wrap via the modulo, which realizes
// the broadcast.
int64_t GetElementIndex(const std::vector<int64_t> &shape,
                        const std::vector<int64_t> &current_index) {
  int64_t flat_index = 0;
  int64_t stride = 1;
  // Walk dimensions from innermost to outermost, accumulating row-major
  // strides as we go.
  for (int64_t dim = static_cast<int64_t>(shape.size()) - 1; dim >= 0; --dim) {
    flat_index += (current_index[dim] % shape[dim]) * stride;
    stride *= shape[dim];
  }
  return flat_index;
}
// Advances `*current_index_ptr` to the next position of `result_shape`
// in row-major order, wrapping each dimension like an odometer digit.
void IncrementIndex(ArrayRef<int64_t> result_shape,
                    std::vector<int64_t> *current_index_ptr) {
  std::vector<int64_t> &index = *current_index_ptr;
  for (int64_t dim = static_cast<int64_t>(result_shape.size()) - 1; dim >= 0;
       --dim) {
    // If this digit did not overflow we are done; otherwise reset it to
    // zero and carry into the next-outer dimension.
    if (++index[dim] < result_shape[dim]) return;
    index[dim] = 0;
  }
}
/// Performs const folding `calculate` with broadcast behavior on the two
/// attributes `operand1` and `operand2` and returns the result if possible.
/// This function assumes the both operands are verified to have value
@ -107,23 +145,10 @@ template <class AttrElementT,
Attribute ConstFoldBinaryOpDenseDense(Type result_type, DenseElementsAttr lhs,
DenseElementsAttr rhs,
const CalculationT &calculate) {
auto type = result_type.cast<ShapedType>();
if (lhs.getType() != rhs.getType()) {
// We only support the case that one of the operand's dimensions are
// a perfect suffix of the other.
// TODO: support the general broadcast behavior.
auto lhs_shape = lhs.getType().getShape();
auto rhs_shape = rhs.getType().getShape();
if (IsTrailingDimensions(lhs_shape, rhs_shape)) {
if (!type.hasStaticShape()) type = rhs.getType();
} else if (IsTrailingDimensions(rhs_shape, lhs_shape)) {
if (!type.hasStaticShape()) type = lhs.getType();
} else {
return {};
}
} else if (!type.hasStaticShape()) {
type = lhs.getType();
auto type = OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType())
.dyn_cast_or_null<ShapedType>();
if (!type) {
return {};
}
const bool rhs_is_splat = rhs.isSplat();
@ -139,14 +164,7 @@ Attribute ConstFoldBinaryOpDenseDense(Type result_type, DenseElementsAttr lhs,
return DenseElementsAttr::get(type, element_result);
}
auto lhs_num_elements = lhs.getType().getNumElements();
auto rhs_num_elements = rhs.getType().getNumElements();
auto num_elements = std::max(lhs_num_elements, rhs_num_elements);
// We assume the arguments have broadcast-compatible types. Make sure again.
assert(std::max(lhs_num_elements, rhs_num_elements) == num_elements);
assert(num_elements % std::min(lhs_num_elements, rhs_num_elements) == 0);
auto num_elements = type.getNumElements();
SmallVector<ElementValueT, 16> lhs_old_values;
SmallVector<ElementValueT, 16> rhs_old_values;
if (lhs_is_splat)
@ -157,31 +175,32 @@ Attribute ConstFoldBinaryOpDenseDense(Type result_type, DenseElementsAttr lhs,
rhs_old_values.push_back(rhs.getSplatValue<ElementValueT>());
else
rhs_old_values = llvm::to_vector<16>(rhs.getValues<ElementValueT>());
SmallVector<ElementValueT, 16> new_values;
new_values.reserve(num_elements);
const auto result_shape = type.getShape();
std::vector<int64_t> current_index(type.getRank(), 0);
// Create the new shape with ones padded to the left.
std::vector<int64_t> lhs_new_shape =
GetPaddedShape(lhs.getType().getShape(), type.getRank());
std::vector<int64_t> rhs_new_shape =
GetPaddedShape(rhs.getType().getShape(), type.getRank());
// Add each pair of the corresponding values in the dense elements
// attributes.
for (int i = 0; i < num_elements; ++i) {
// We only support a degenerated case here: the dimensions in one operand's
// shape is a perfect suffix to the other operand. Then conceptually it's
// similar to broadcasting a scalar to a 1-D vector.
// TODO: support the general broadcast behavior.
// We are tiling the operand with less elements an integral times to match
// the operand with more elements. We don't care which operand has less
// elements here because we are iterating its elements in circles, which can
// be achieved using the result index modulo the element count. For the
// operand with more elements, since the result has the same number of
// elements, we are only going over its elements once. The modulo operation
// also works for that.
int lhs_index = lhs_is_splat ? 0 : (i % lhs_num_elements);
int rhs_index = rhs_is_splat ? 0 : (i % rhs_num_elements);
for (int64_t i = 0; i < num_elements; ++i) {
// current_index represents the index
// in the N-dimension tensor. GetElementIndex returns
// the index in the flat representation of the original tensor
// to use.
int64_t lhs_index =
lhs_is_splat ? 0 : GetElementIndex(lhs_new_shape, current_index);
int64_t rhs_index =
rhs_is_splat ? 0 : GetElementIndex(rhs_new_shape, current_index);
new_values.push_back(
calculate(lhs_old_values[lhs_index], rhs_old_values[rhs_index]));
IncrementIndex(result_shape, &current_index);
}
return DenseElementsAttr::get(type, new_values);
}
@ -308,7 +327,7 @@ void buildFusedBroadcastableBinOp(Builder *builder, OperationState &result,
//===----------------------------------------------------------------------===//
OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
// Skip fused ops for now.
// TODO(b/142478136): Handle fused ops.
if (fused_activation_function() != "NONE") return {};
return ConstFoldBinaryOp(
getType(), operands, [](APFloat a, APFloat b) { return a + b; },
@ -636,13 +655,26 @@ static void BuildGatherOp(Builder *builder, OperationState &result,
//===----------------------------------------------------------------------===//
// Constant-folds tfl.mul when both operands are constant attributes.
// TODO(b/142478136): Handle fused ops.
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
  // Folding with a fused activation function is not supported yet.
  if (fused_activation_function() != "NONE") return {};
  auto float_mul = [](APFloat lhs, APFloat rhs) { return lhs * rhs; };
  auto int_mul = [](APInt lhs, APInt rhs) { return lhs * rhs; };
  return ConstFoldBinaryOp(getType(), operands, float_mul, int_mul,
                           getOperation()->isCommutative());
}
//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//
// Constant-folds tfl.div when both operands are constant attributes and
// there is no fused activation function.
OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
// TODO(b/142478136): Handle fused ops.
if (fused_activation_function() != "NONE") return {};
// NOTE(review): the integer path uses APInt::sdiv, which asserts on a zero
// divisor -- folding an integer division by a constant 0 would abort the
// compiler. Confirm whether a zero-divisor guard is needed before folding.
return ConstFoldBinaryOp(
getType(), operands, [](APFloat a, APFloat b) { return a / b; },
[](APInt a, APInt b) { return a.sdiv(b); },
getOperation()->isCommutative());
}
//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//
@ -922,7 +954,7 @@ static LogicalResult Verify(SliceOp op) {
//===----------------------------------------------------------------------===//
OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
// Skip fused ops for now.
// TODO(b/142478136): Handle fused ops.
if (fused_activation_function() != "NONE") return {};
return ConstFoldBinaryOp(
getType(), operands, [](APFloat a, APFloat b) { return a - b; },

View File

@ -908,6 +908,8 @@ def TFL_DivOp : TFL_Op<"div", [Broadcastable, NoSideEffect]> {
let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
let hasOptions = 1;
let hasFolder = 1;
}
def TFL_EluOp: TFL_Op<"elu", [NoSideEffect, SameOperandsAndResultType]> {

View File

@ -226,11 +226,8 @@ func @add_dense_dense_int_mixing_1_n() -> tensor<2x2xi32> {
%0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
// We don't support this case yet.
// %cst = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32>
// CHECK: %0 = "tfl.add"
// CHECK: return %0
// CHECK: %cst = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32>
// CHECK: return %cst
}
// CHECK-LABEL: @add_dense_splat_float
@ -299,9 +296,8 @@ func @add_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
return %0 : tensor<2x2xf32>
// We don't support this case yet.
// CHECK: %0 = "tfl.add"
// CHECK: return %0
// CHECK: %cst = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32>
// CHECK: return %cst
}
// CHECK-LABEL: @rank
@ -555,3 +551,29 @@ func @concatConstantTensorsLastDim() -> tensor<1x2x6xi32> {
// CHECK-NOT: "tfl.concatenation"
// CHECK: return [[cst]]
}
// Checks tfl.div folding with 1xN / Nx1 broadcasting of float constants.
// Renamed "mixfng" -> "mixing" to fix the typo (matches the spelling of
// @add_dense_dense_int_mixing_1_n).
// CHECK-LABEL: @div_dense_dense_float_mixing_1_n
func @div_dense_dense_float_mixing_1_n() -> tensor<2x2xf32> {
%cst_0 = constant dense<[[1.5, -2.5]]> : tensor<1x2xf32>
%cst_1 = constant dense<[[-3.], [4.]]> : tensor<2x1xf32>
%0 = "tfl.div"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
// CHECK: %cst = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32>
// CHECK: return %cst
}
// Checks tfl.div folding when the operands have different ranks (3 vs 2);
// the lower-rank shape is broadcast against the higher-rank one and the
// division folds to a single constant.
// CHECK-LABEL: @div_dense_different_rank
func @div_dense_different_rank() -> tensor<1x2x2xf32> {
%cst_0 = constant dense<[[[1.0],[2.0]]]> : tensor<1x2x1xf32>
%cst_1 = constant dense<[[2.0, 3.0]]> : tensor<1x2xf32>
%0 = "tfl.div"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2x1xf32>, tensor<1x2xf32>) -> tensor<1x2x2xf32>
return %0 : tensor<1x2x2xf32>
// CHECK: %cst = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32>
// CHECK: return %cst
}