Merge pull request #42135 from WindQAQ:parallel-addn

PiperOrigin-RevId: 326010033
Change-Id: I62bb01dca665603b4b7eb3f01415c3f3fcf8b55d
This commit is contained in:
TensorFlower Gardener 2020-08-11 06:43:15 -07:00
commit aee9ca5dbd
2 changed files with 78 additions and 10 deletions

View File

@ -353,8 +353,16 @@ func @ZerosLike_variant(%arg0: tensor<!tf.variant<tensor<2xi32>>>) -> tensor<!tf
return %0 : tensor<!tf.variant<tensor<2xi32>>>
}
// CHECK-LABEL: func @addN
func @addN(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-LABEL: func @addN_2
func @addN_2(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  // A two-operand tf.AddN is expected to lower to a single tf.AddV2 whose
  // result is the value returned from the function.
  // CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1)
  // CHECK: return %[[SUM0]]
  %0 = "tf.AddN"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @addN_3
func @addN_3(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1)
// CHECK: %[[SUM1:.*]] = "tf.AddV2"(%[[SUM0]], %arg2)
// CHECK: return %[[SUM1]]
@ -362,6 +370,27 @@ func @addN(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) ->
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @addN_4
func @addN_4(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>, %arg3: tensor<*xf32>) -> tensor<*xf32> {
  // Four operands are expected to be reduced pairwise: (arg0 + arg1) and
  // (arg2 + arg3) are formed first, then the two partial sums are added and
  // that final sum is returned.
  // CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1)
  // CHECK: %[[SUM1:.*]] = "tf.AddV2"(%arg2, %arg3)
  // CHECK: %[[SUM2:.*]] = "tf.AddV2"(%[[SUM0]], %[[SUM1]])
  // CHECK: return %[[SUM2]]
  %0 = "tf.AddN"(%arg0, %arg1, %arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @addN_5
func @addN_5(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>, %arg3: tensor<*xf32>, %arg4: tensor<*xf32>) -> tensor<*xf32> {
  // With an odd operand count the first four operands are reduced pairwise
  // and the unpaired last operand (arg4) is added to the combined sum at the
  // end; that final sum is returned.
  // CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1)
  // CHECK: %[[SUM1:.*]] = "tf.AddV2"(%arg2, %arg3)
  // CHECK: %[[SUM2:.*]] = "tf.AddV2"(%[[SUM0]], %[[SUM1]])
  // CHECK: %[[SUM3:.*]] = "tf.AddV2"(%[[SUM2]], %arg4)
  // CHECK: return %[[SUM3]]
  %0 = "tf.AddN"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @addN_variant
func @addN_variant(%arg0: tensor<!tf.variant<tensor<2xf32>>>, %arg1: tensor<!tf.variant<tensor<2xf32>>>, %arg2: tensor<!tf.variant<tensor<2xf32>>>) -> tensor<!tf.variant<tensor<2xf32>>> {
// CHECK: tf.AddN

View File

@ -113,12 +113,42 @@ Type InferExpandDimsType(Type ty, int64_t axis, Builder *builder) {
// Lowers AddN op to a sequence of AddV2 ops to accumulate operands.
//
// Note that, to improve parallelism, the AddN op is lowered using a tree-based reduction.
// For example, tf.AddN([0, 1, 2, 3, 4]) behaves as follows:
//
//   0     1     2     3     4
//   |     |     |     |     |
//   -------     -------     |
//      |           |        |
//      5           6        |
//      |           |        |
//      -------------        |
//            |              |
//            7              |
//            |              |
//            ----------------
//                   |
//                   8
//
// Example:
//
// %result = "tf.AddN"(%0, %1, %2)
//
// is lowered to:
//
// %sum_0 = "tf.AddV2"(%0, %1)
// %result = "tf.AddV2"(%sum_0, %2)
// %sum0 = "tf.AddV2"(%0, %1)
// %result = "tf.AddV2"(%sum0, %2)
//
// While
//
// %result = "tf.AddN"(%0, %1, %2, %3, %4)
//
// is lowered to:
//
// %sum0 = "tf.AddV2"(%0, %1)
// %sum1 = "tf.AddV2"(%2, %3)
// %sum2 = "tf.AddV2"(%sum0, %sum1)
// %result = "tf.AddV2"(%sum2, %4)
//
class LowerAddNOp : public OpRewritePattern<TF::AddNOp> {
public:
@ -131,14 +161,23 @@ class LowerAddNOp : public OpRewritePattern<TF::AddNOp> {
// support variant type so variant types require special handling.
if (getElementTypeOrSelf(op.getType()).isa<VariantType>()) return failure();
// TODO(hinsu): Improve parallelism by splitting operands in two halves and
// accumulating them first.
Value result = *op.inputs().begin();
for (Value operand : llvm::drop_begin(op.inputs(), 1)) {
result = rewriter.create<TF::AddV2Op>(op.getLoc(), result, operand);
llvm::SmallVector<Value, 4> operands(op.inputs().begin(),
op.inputs().end());
int64_t n = operands.size();
// Keep doing tree-based reduction when there are more than one operand.
while (n > 1) {
for (int64_t i = 0; i < n; i += 2) {
// Add two adjacent operands if applicable.
operands[i / 2] = (i + 1 < n)
? rewriter.create<TF::AddV2Op>(
op.getLoc(), operands[i], operands[i + 1])
: operands[i];
}
n = (n + 1) / 2;
}
rewriter.replaceOp(op, result);
rewriter.replaceOp(op, operands[0]);
return success();
}
};