Merge pull request #42135 from WindQAQ:parallel-addn

PiperOrigin-RevId: 326010033
Change-Id: I62bb01dca665603b4b7eb3f01415c3f3fcf8b55d

commit aee9ca5dbd
@@ -353,8 +353,16 @@ func @ZerosLike_variant(%arg0: tensor<!tf.variant<tensor<2xi32>>>) -> tensor<!tf.variant<tensor<2xi32>>>
   return %0 : tensor<!tf.variant<tensor<2xi32>>>
 }
 
-// CHECK-LABEL: func @addN
-func @addN(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK-LABEL: func @addN_2
+func @addN_2(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1)
+  // return %[[SUM0]]
+  %0 = "tf.AddN"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// CHECK-LABEL: func @addN_3
+func @addN_3(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<*xf32> {
   // CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1)
   // CHECK: %[[SUM1:.*]] = "tf.AddV2"(%[[SUM0]], %arg2)
   // return %[[SUM1]]
@@ -362,6 +370,27 @@ func @addN(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
 
+// CHECK-LABEL: func @addN_4
+func @addN_4(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>, %arg3: tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1)
+  // CHECK: %[[SUM1:.*]] = "tf.AddV2"(%arg2, %arg3)
+  // CHECK: %[[SUM2:.*]] = "tf.AddV2"(%[[SUM0]], %[[SUM1]])
+  // return %[[SUM2]]
+  %0 = "tf.AddN"(%arg0, %arg1, %arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// CHECK-LABEL: func @addN_5
+func @addN_5(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>, %arg3: tensor<*xf32>, %arg4: tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1)
+  // CHECK: %[[SUM1:.*]] = "tf.AddV2"(%arg2, %arg3)
+  // CHECK: %[[SUM2:.*]] = "tf.AddV2"(%[[SUM0]], %[[SUM1]])
+  // CHECK: %[[SUM3:.*]] = "tf.AddV2"(%[[SUM2]], %arg4)
+  // return %[[SUM3]]
+  %0 = "tf.AddN"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
 // CHECK-LABEL: func @addN_variant
 func @addN_variant(%arg0: tensor<!tf.variant<tensor<2xf32>>>, %arg1: tensor<!tf.variant<tensor<2xf32>>>, %arg2: tensor<!tf.variant<tensor<2xf32>>>) -> tensor<!tf.variant<tensor<2xf32>>> {
   // CHECK: tf.AddN
@@ -113,12 +113,42 @@ Type InferExpandDimsType(Type ty, int64_t axis, Builder *builder) {
 
 // Lowers AddN op to a sequence of AddV2 ops to accumulate operands.
 //
+// Note that to improve parallelism, the AddN op uses tree-based reduction.
+// For example, tf.AddN([0, 1, 2, 3, 4]) behaves as follows:
+//
+//          0     1     2     3     4
+//          |     |     |     |     |
+//          -------     -------     |
+//             |           |        |
+//             5           6        |
+//             |           |        |
+//             -------------        |
+//                   |              |
+//                   7              |
+//                   |              |
+//                   ----------------
+//                          |
+//                          8
+//
 // Example:
 //
 //   %result = "tf.AddN"(%0, %1, %2)
 //
 // is lowered to:
 //
-//   %sum_0 = "tf.AddV2"(%0, %1)
-//   %result = "tf.AddV2"(%sum_0, %2)
+//   %sum0 = "tf.AddV2"(%0, %1)
+//   %result = "tf.AddV2"(%sum0, %2)
 //
+// While
+//
+//   %result = "tf.AddN"(%0, %1, %2, %3, %4)
+//
+// is lowered to:
+//
+//   %sum0 = "tf.AddV2"(%0, %1)
+//   %sum1 = "tf.AddV2"(%2, %3)
+//   %sum2 = "tf.AddV2"(%sum0, %sum1)
+//   %result = "tf.AddV2"(%sum2, %4)
+//
 class LowerAddNOp : public OpRewritePattern<TF::AddNOp> {
  public:
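Reviewer note: the parallelism gain described in the new comment can be quantified. A left-to-right chain of AddV2 ops has depth n - 1, while the pairwise tree has depth ceil(log2(n)). The standalone C++ sketch below is illustrative only and not part of the patch; chain_depth and tree_depth are hypothetical helpers, but tree_depth uses the same repeated-halving step as the lowering.

#include <cstdint>
#include <iostream>

// Depth of a left-to-right accumulation chain over n operands: n - 1 adds,
// all serialized on one dependency chain.
int64_t chain_depth(int64_t n) { return n - 1; }

// Depth of a pairwise reduction tree: number of halving levels until one
// value remains, i.e. ceil(log2(n)), computed with the same n = (n + 1) / 2
// step the lowering uses.
int64_t tree_depth(int64_t n) {
  int64_t depth = 0;
  while (n > 1) {
    n = (n + 1) / 2;
    ++depth;
  }
  return depth;
}

int main() {
  for (int64_t n : {2, 3, 4, 5, 8, 100}) {
    std::cout << "n=" << n << "  chain=" << chain_depth(n)
              << "  tree=" << tree_depth(n) << "\n";
  }
  // For n = 5: chain depth 4 vs. tree depth 3, matching the diagram above.
  return 0;
}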
@@ -131,14 +161,23 @@ class LowerAddNOp : public OpRewritePattern<TF::AddNOp> {
     // support variant type so variant types require special handling.
     if (getElementTypeOrSelf(op.getType()).isa<VariantType>()) return failure();
 
-    // TODO(hinsu): Improve parallelism by splitting operands in two halves and
-    // accumulating them first.
-    Value result = *op.inputs().begin();
-    for (Value operand : llvm::drop_begin(op.inputs(), 1)) {
-      result = rewriter.create<TF::AddV2Op>(op.getLoc(), result, operand);
+    llvm::SmallVector<Value, 4> operands(op.inputs().begin(),
+                                         op.inputs().end());
+
+    int64_t n = operands.size();
+    // Keep doing tree-based reduction while there is more than one operand.
+    while (n > 1) {
+      for (int64_t i = 0; i < n; i += 2) {
+        // Add two adjacent operands if applicable.
+        operands[i / 2] = (i + 1 < n)
+                              ? rewriter.create<TF::AddV2Op>(
+                                    op.getLoc(), operands[i], operands[i + 1])
+                              : operands[i];
+      }
+      n = (n + 1) / 2;
     }
 
-    rewriter.replaceOp(op, result);
+    rewriter.replaceOp(op, operands[0]);
     return success();
   }
 };