From 78d28301bcde34028a201a6288e6d409d66e254a Mon Sep 17 00:00:00 2001
From: Mingsheng Hong
Date: Wed, 10 Feb 2021 14:03:14 -0800
Subject: [PATCH] Enhanced the HoistCwiseBinaryOutOfConcat rewrite pattern to
 synthesize input binary ops when needed.

When a majority of the inputs to a ConcatV2 op are produced by a binary op of
kind B (e.g. tf.Mul), we now synthesize a tf.Mul for each of the remaining
inputs, using the identity scalar const tensor 1.0. This allows us to apply
HoistCwiseBinaryOutOfConcat to a wider range of workloads.

PiperOrigin-RevId: 356821582
Change-Id: I252ba9abc3248953a88be978c98c3a3d4b55b089
---
 .../compiler/mlir/tensorflow/ir/tf_ops_a_m.cc | 117 ++++++++++++++++--
 .../mlir/tensorflow/tests/canonicalize.mlir   |  44 +++++++
 2 files changed, 148 insertions(+), 13 deletions(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index 6a0ab9a67fa..3bb7861537b 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -1039,6 +1039,14 @@ LogicalResult HoistCwiseUnaryOutOfConcat::matchAndRewrite(
 //   %1 = tf.ConcatV2(%rhs0, %rhs1, ..., %rhs_n, %rhs_concat_axis)
 //   %2 = tf.Mul(%0, %1)
 //
+// If a minor fraction of the Concat inputs are not of the same binary op kind
+// (tf.Mul in the above example), we will synthesize the binary ops for those
+// inputs, e.g. if we instead have %1 = %lhs_1, then we would synthesize a
+// tf.Mul op over it and a scalar const tensor 1.0. For now this only applies
+// to float32 tensors.
+// TODO(hongm): Implement this op synthesis optimization for other dtypes if
+// needed.
+//
 // Because coefficient-wise binary operations support implicit broadcasting, we
 // should be very careful with this optimization, and do not accidentally
 // produce incorrect concat operations.
@@ -1057,11 +1065,17 @@ class HoistCwiseBinaryOutOfConcat : public OpRewritePattern<TF::ConcatV2Op> {
     int64_t rhs_axis;
     Type lhs_concat_type;
     Type rhs_concat_type;
+    int scalar_operand_idx;  // can be 0 or 1 for the binary op's operands.
   };

   // Returns parameters of a binary op hoisting out of concatenation if all of
   // the operands are in one of the compatible configurations.
-  Optional<HoistParams> GetHoistParams(TF::ConcatV2Op op, int64_t axis) const;
+  // All inputs of `op` should be of the same binary op kind (e.g. tf.Mul),
+  // except for the ones in `exceptions`. In that case, we can synthesize that
+  // binary op kind for the values in `exceptions`.
+  Optional<HoistParams> GetHoistParams(
+      TF::ConcatV2Op op, int64_t axis,
+      const llvm::SmallDenseMap<Value, unsigned, 4> &exceptions) const;
 };

 LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
@@ -1078,24 +1092,86 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
   // on the channels dim for NCHW layout as axis=-2.
   if (axis < 0) return failure();

-  // All concat operands must be defined by ops.
+  // All concat operands must be defined by ops of the same kind (e.g. tf.Mul),
+  // or some other ops that we might convert to using the same op kind above
+  // (e.g. converting op A to tf.Mul(A, 1.0)).
+  // TODO(hongm): generalize the code here to support cases where the first arg
+  // has no defining op (e.g. might be a block arg).
   Operation *first_arg_op = op.values().front().getDefiningOp();
   if (first_arg_op == nullptr) return failure();

   // All concat operands must be produced by the coeff-wise binary operation.
   if (!first_arg_op->hasTrait<OpTrait::TF::CwiseBinary>()) return failure();

-  // All concat operands must be defined by the op of same kind.
-  bool args_same_op = llvm::all_of(op.values(), [&](Value arg) -> bool {
+  // All concat operands must be defined by the op of same kind, except for a
+  // minor portion which we track in `exceptions`.
+  // Map from the operands to operand indices.
+  llvm::SmallDenseMap<Value, unsigned, 4> exceptions;
+  unsigned operand_idx = 0;
+  for (Value arg : op.values()) {
     Operation *arg_op = arg.getDefiningOp();
-    return arg_op && arg_op->getName() == first_arg_op->getName();
-  });
-  if (!args_same_op) return failure();
+    if (arg_op && arg_op->getName() == first_arg_op->getName()) {
+      ++operand_idx;
+      continue;
+    }
+    exceptions[arg] = operand_idx++;
+  }
+
+  // Recall that the inputs to the concat op that are not produced by a binary
+  // op of the `first_arg_op` kind (e.g. tf.Mul) are stored in `exceptions`. If
+  // there are too many exceptions, it might not be cost effective to apply the
+  // concat hoisting optimization here.
+  // Setting the threshold to 50% as a simple cost model heuristic, e.g. if 1
+  // out of 2 concat inputs is an exception, we don't apply the hoist; if it's
+  // 1 out of 3, we do.
+  const float exception_pct_threshold = 0.5;
+  if (static_cast<float>(op.values().size()) * exception_pct_threshold <=
+      exceptions.size())
+    return failure();

   // Compute binary operands hoist parameters.
-  auto hoist_params = GetHoistParams(op, axis);
+  auto hoist_params = GetHoistParams(op, axis, exceptions);
   if (!hoist_params.hasValue()) return failure();

+  // Process `exceptions`: for each value there, synthesize a binary op of the
+  // above kind, so that the concat hoisting optimization can still apply.
+  if (!exceptions.empty()) {
+    int identity_val;
+    if (isa<TF::AddV2Op>(first_arg_op) || isa<TF::SubOp>(first_arg_op))
+      identity_val = 0;
+    else if (isa<TF::MulOp>(first_arg_op) || isa<TF::DivOp>(first_arg_op) ||
+             isa<TF::RealDivOp>(first_arg_op))
+      identity_val = 1;
+    else
+      return failure();
+    DenseElementsAttr const_attr;
+    auto scalar_tensor_type =
+        first_arg_op->getOperand(hoist_params->scalar_operand_idx)
+            .getType()
+            .dyn_cast<RankedTensorType>();
+    Type scalar_dtype = scalar_tensor_type.getElementType();
+    if (scalar_dtype.isa<FloatType>())
+      const_attr = DenseElementsAttr::get(scalar_tensor_type,
+                                          static_cast<float>(identity_val));
+    else
+      return failure();
+
+    // All checks have passed, and we now prepare for the rewrite.
+    auto identity_const = rewriter.create<TF::ConstOp>(loc, const_attr);
+    for (const auto &kv : exceptions) {
+      assert(!hoist_params->lhs_args[kv.second]);
+      assert(!hoist_params->rhs_args[kv.second]);
+
+      if (hoist_params->scalar_operand_idx == 1) {
+        hoist_params->lhs_args[kv.second] = kv.first;
+        hoist_params->rhs_args[kv.second] = identity_const;
+      } else {
+        assert(hoist_params->scalar_operand_idx == 0);
+        hoist_params->lhs_args[kv.second] = identity_const;
+        hoist_params->rhs_args[kv.second] = kv.first;
+      }
+    }
+  }
+
   // New lhs and rhs concatenation axis.
   auto axis_type = mlir::RankedTensorType::get({}, rewriter.getIntegerType(64));
   auto lhs_axis = rewriter.create<TF::ConstOp>(
@@ -1122,12 +1198,14 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
 }

 Optional<HoistCwiseBinaryOutOfConcat::HoistParams>
-HoistCwiseBinaryOutOfConcat::GetHoistParams(TF::ConcatV2Op op,
-                                            int64_t axis) const {
+HoistCwiseBinaryOutOfConcat::GetHoistParams(
+    TF::ConcatV2Op op, int64_t axis,
+    const llvm::SmallDenseMap<Value, unsigned, 4> &exceptions) const {
   assert(axis >= 0);
   // Collects lhs or rhs arguments of concat op operands.
   auto args = [&](int operand_idx) -> SmallVector<Value, 8> {
     auto range = llvm::map_range(op.values(), [&](Value arg) {
+      if (exceptions.count(arg)) return Value();
       return arg.getDefiningOp()->getOperand(operand_idx);
     });
     return {range.begin(), range.end()};
@@ -1137,6 +1215,7 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams(TF::ConcatV2Op op,
   // of `axis + 1` rank and axis dim has size `1`.
   auto is_all_tensors = [&](int operand_idx, int axis) -> bool {
     return llvm::all_of(op.values(), [&](Value arg) -> bool {
+      if (exceptions.count(arg)) return true;
       auto operand = arg.getDefiningOp()->getOperand(operand_idx);
       auto ranked = operand.getType().dyn_cast<RankedTensorType>();
       return ranked && ranked.getRank() == (axis + 1) &&
@@ -1147,6 +1226,7 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams(TF::ConcatV2Op op,
   // Returns true if all binary ops operands at `operand_idx` index are scalars.
   auto is_all_scalars = [&](int operand_idx) -> bool {
     return llvm::all_of(op.values(), [&](Value arg) -> bool {
+      if (exceptions.count(arg)) return true;
       auto operand = arg.getDefiningOp()->getOperand(operand_idx);
       auto ranked = operand.getType().dyn_cast<RankedTensorType>();
       return ranked && ranked.hasRank() && ranked.getRank() == 0;
@@ -1168,13 +1248,24 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams(TF::ConcatV2Op op,
   if (is_all_tensors(0, axis) && is_all_scalars(1)) {
     std::array<int64_t, 1> rhs_dims{static_cast<int64_t>(op.values().size())};
     auto rhs_type = RankedTensorType::get(rhs_dims, ranked.getElementType());
-    return HoistParams{args(0), args(1), axis, 0, op.getType(), rhs_type};
+    return HoistParams{args(0),
+                       args(1),
+                       axis,
+                       0,
+                       op.getType(),
+                       rhs_type,
+                       /*scalar_operand_idx=*/1};
   } else if (is_all_tensors(1, axis) && is_all_scalars(0)) {
     std::array<int64_t, 1> lhs_dims{static_cast<int64_t>(op.values().size())};
     auto lhs_type = RankedTensorType::get(lhs_dims, ranked.getElementType());
-    return HoistParams{args(0), args(1), 0, axis, lhs_type, op.getType()};
+    return HoistParams{args(0),
+                       args(1),
+                       0,
+                       axis,
+                       lhs_type,
+                       op.getType(),
+                       /*scalar_operand_idx=*/0};
   }
-
   return None;
 }

diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
index 841e6ddb1cf..90117eea3a5 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
@@ -226,6 +226,50 @@ func @testConcatCwiseBinaryNegativeAxis(%arg0: tensor<f32>,
   return %3 : tensor<2xf32>
 }

+// Synthesize binary ops when 1 of the 3 concat inputs is a non-binary op.
+// CHECK-LABEL: testConcatCwiseBinarySynthMulOp3Inputs
+func @testConcatCwiseBinarySynthMulOp3Inputs(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x3xf32> {
+  // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<[2.000000e+00, 3.000000e+00, 1.000000e+00]>
+  // CHECK: %[[CONCAT:.*]] = "tf.ConcatV2"(%arg0, %arg1, %arg2,
+  // CHECK: "tf.Mul"(%[[CONCAT]], %[[CONST]])
+  %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
+  %mul0_const = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
+  %mul0 = "tf.Mul"(%arg0, %mul0_const) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
+  %mul1_const = "tf.Const"() { value = dense<3.0> : tensor<f32> } : () -> tensor<f32>
+  %mul1 = "tf.Mul"(%arg1, %mul1_const) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
+  %ret = "tf.ConcatV2"(%mul0, %mul1, %arg2, %axis) : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>, tensor<i32>) -> tensor<?x3xf32>
+
+  return %ret : tensor<?x3xf32>
+}
+
+// Similar to the above, with tf.Sub as the binary op kind.
+func @testConcatCwiseBinarySynthSubOp3Inputs(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x3xf32> {
+  // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<[2.000000e+00, 3.000000e+00, 0.000000e+00]>
+  // CHECK: %[[CONCAT:.*]] = "tf.ConcatV2"(%arg0, %arg1, %arg2,
+  // CHECK: "tf.Sub"(%[[CONCAT]], %[[CONST]])
+  %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
+  %mul0_const = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
+  %mul0 = "tf.Sub"(%arg0, %mul0_const) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
+  %mul1_const = "tf.Const"() { value = dense<3.0> : tensor<f32> } : () -> tensor<f32>
+  %mul1 = "tf.Sub"(%arg1, %mul1_const) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
+  %ret = "tf.ConcatV2"(%mul0, %mul1, %arg2, %axis) : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>, tensor<i32>) -> tensor<?x3xf32>
+
+  return %ret : tensor<?x3xf32>
+}
+
+// Do not synthesize binary ops when 1 of the 2 concat inputs is a non-binary op.
+// CHECK-LABEL: testConcatCwiseBinarySynthMulOp2Inputs
+func @testConcatCwiseBinarySynthMulOp2Inputs(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>) -> tensor<?x2xf32> {
+  // CHECK: %[[MUL:.*]] = "tf.Mul"(%arg0,
+  // CHECK: "tf.ConcatV2"(%[[MUL]], %arg1,
+  %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
+  %mul0_const = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
+  %mul0 = "tf.Mul"(%arg0, %mul0_const) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
+  %ret = "tf.ConcatV2"(%mul0, %arg1, %axis) : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<i32>) -> tensor<?x2xf32>
+
+  return %ret : tensor<?x2xf32>
+}
+
 // CHECK-LABEL: testLogOfSoftmax
 func @testLogOfSoftmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
   %0 = "tf.Softmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
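
Note: to make the CHECK expectations in testConcatCwiseBinarySynthMulOp3Inputs
concrete, the canonicalized IR produced by the enhanced pattern is expected to
look roughly like the sketch below, after the hoisted concat of the scalar
operands (2.0, 3.0, and the synthesized identity 1.0) is constant folded. The
SSA value names, the i64 axis type, and the exact attribute printing are
illustrative assumptions; only the op sequence and the folded [2.0, 3.0, 1.0]
constant follow from the CHECK lines above.

func @testConcatCwiseBinarySynthMulOp3Inputs(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x3xf32> {
  // Folded concat of the two original scalar multipliers plus the synthesized identity 1.0.
  %cst = "tf.Const"() {value = dense<[2.000000e+00, 3.000000e+00, 1.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32>
  // Hoisted concat of the non-scalar operands along the original axis.
  %axis = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
  %concat = "tf.ConcatV2"(%arg0, %arg1, %arg2, %axis) : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>, tensor<i64>) -> tensor<?x3xf32>
  // A single binary op applied once to the concatenated operands, with broadcasting.
  %mul = "tf.Mul"(%concat, %cst) : (tensor<?x3xf32>, tensor<3xf32>) -> tensor<?x3xf32>
  return %mul : tensor<?x3xf32>
}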