Enhanced the HoistCwiseBinaryOutOfConcat rewrite pattern to synthesize input binary ops when needed.

Where a majority fraction of the inputs to a ConcatV2 op are produced by a binary op kind B (e.g. tf.Mul), for the remaining inputs, we now synthesize tf.Mul for each of them respectively, with the identity scalar const tensor 1.0. This allows us to apply HoistCwiseBinaryOutOfConcat to a wide range of workloads.

PiperOrigin-RevId: 356821582
Change-Id: I252ba9abc3248953a88be978c98c3a3d4b55b089
This commit is contained in:
Mingsheng Hong 2021-02-10 14:03:14 -08:00 committed by TensorFlower Gardener
parent 7e60a5de30
commit 78d28301bc
2 changed files with 148 additions and 13 deletions

View File

@ -1039,6 +1039,14 @@ LogicalResult HoistCwiseUnaryOutOfConcat::matchAndRewrite(
// %1 = tf.ConcatV2(%rhs0, %rhs1, ..., %rhs_n, %rhs_concat_axis)
// %2 = tf.Mul(%0, %1)
//
// If a minor fraction of the Concat inputs are not of the same binary op kind
// (tf.Mul in the above example), we will synthesize the binary ops for those
// inputs. e.g. if we instead have %1 = %lhs_1, then we would synthesize a
// tf.Mul op over it and a scalar const tensor 1.0. For now this only applies to
// float32 tensors.
// TODO(hongm): Implement this op synthesis optimization for other dtypes if
// needed.
//
// Because coefficient-wise binary operations support implicit broadcasting, we
// should be very careful with this optimization, and do not accidentally
// produce incorrect concat operations.
@ -1057,11 +1065,17 @@ class HoistCwiseBinaryOutOfConcat : public OpRewritePattern<TF::ConcatV2Op> {
int64_t rhs_axis;
Type lhs_concat_type;
Type rhs_concat_type;
int scalar_operand_idx; // can be 0 or 1 for the binary op's operands.
};
// Returns parameters of a binary op hoisting out of concatenation if all of
// the operands are in one of the compatible configurations.
Optional<HoistParams> GetHoistParams(TF::ConcatV2Op op, int64_t axis) const;
// All inputs of `op` should be of the same binary op kind (e.g. tf.Mul),
// except from the ones in `exceptions`. In that case, we can synthesize that
// binary op kind for the values in `exceptions`.
Optional<HoistParams> GetHoistParams(
TF::ConcatV2Op op, int64_t axis,
const llvm::SmallDenseMap<Value, unsigned, 4> &exceptions) const;
};
LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
@ -1078,24 +1092,86 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
// on the channels dim for NCHW layout as axis=-2.
if (axis < 0) return failure();
// All concat operands must be defined by ops.
// All concat operands must be defined by ops of the same kind (e.g. tf.Mul),
// or some other ops that we might convert to using the same op kind above
// (e.g. converting op A to tf.Mul(A, 1.0))
// TODO(hongm): generalize the code here to support cases where the first arg
// has no defining op (e.g. might be a block arg).
Operation *first_arg_op = op.values().front().getDefiningOp();
if (first_arg_op == nullptr) return failure();
// All concat operands must be produced by the coeff-wise binary operation.
if (!first_arg_op->hasTrait<OpTrait::TF::CwiseBinary>()) return failure();
// All concat operands must be defined by the op of same kind.
bool args_same_op = llvm::all_of(op.values(), [&](Value arg) -> bool {
// All concat operands must be defined by the op of same kind, except for a
// minor portion which we track in `exceptions`.
// Map from the operands to operand indices.
llvm::SmallDenseMap<Value, unsigned, 4> exceptions;
unsigned operand_idx = 0;
for (Value arg : op.values()) {
Operation *arg_op = arg.getDefiningOp();
return arg_op && arg_op->getName() == first_arg_op->getName();
});
if (!args_same_op) return failure();
if (arg_op && arg_op->getName() == first_arg_op->getName()) {
++operand_idx;
continue;
}
exceptions[arg] = operand_idx++;
}
// Recall those inputs to the concat op that are not produced by a binary op
// of the `first_arg_op` kind (e.g. tf.Mul) are stored in `exceptions`. If
// there are too many exceptions, it might not be cost effective to apply the
// concat hoisting optimization here.
// Setting the threshold to be 50% as a simple cost model heuristic. e.g. If 1
// out of 2 concat inputs is an exception, we don't apply the hoist. If it's 1
// out of 3, we do.
const float exception_pct_threshold = 0.5;
if (static_cast<float>(op.values().size()) * exception_pct_threshold <=
exceptions.size())
return failure();
// Compute binary operands hoist parameters.
auto hoist_params = GetHoistParams(op, axis);
auto hoist_params = GetHoistParams(op, axis, exceptions);
if (!hoist_params.hasValue()) return failure();
// Process `exceptions`: For each value there, synthesize a binary op of the
// above kind, so that the concat hoisting optimization can still apply.
if (!exceptions.empty()) {
int identity_val;
if (isa<AddOp>(first_arg_op) || isa<SubOp>(first_arg_op))
identity_val = 0;
else if (isa<MulOp>(first_arg_op) || isa<DivOp>(first_arg_op) ||
isa<RealDivOp>(first_arg_op))
identity_val = 1;
else
return failure();
DenseElementsAttr const_attr;
auto scalar_tensor_type =
first_arg_op->getOperand(hoist_params->scalar_operand_idx)
.getType()
.dyn_cast<ShapedType>();
Type scalar_dtype = scalar_tensor_type.getElementType();
if (scalar_dtype.isa<FloatType>())
const_attr = DenseElementsAttr::get(scalar_tensor_type,
static_cast<float>(identity_val));
else
return failure();
// All checks are passes, and we now prepare for rewrite.
auto identity_const = rewriter.create<TF::ConstOp>(loc, const_attr);
for (const auto &kv : exceptions) {
assert(!hoist_params->lhs_args[kv.second]);
assert(!hoist_params->rhs_args[kv.second]);
if (hoist_params->scalar_operand_idx == 1) {
hoist_params->lhs_args[kv.second] = kv.first;
hoist_params->rhs_args[kv.second] = identity_const;
} else {
assert(hoist_params->scalar_operand_idx == 0);
hoist_params->lhs_args[kv.second] = identity_const;
hoist_params->rhs_args[kv.second] = kv.first;
}
}
}
// New lhs and rhs concatenation axis.
auto axis_type = mlir::RankedTensorType::get({}, rewriter.getIntegerType(64));
auto lhs_axis = rewriter.create<TF::ConstOp>(
@ -1122,12 +1198,14 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
}
Optional<HoistCwiseBinaryOutOfConcat::HoistParams>
HoistCwiseBinaryOutOfConcat::GetHoistParams(TF::ConcatV2Op op,
int64_t axis) const {
HoistCwiseBinaryOutOfConcat::GetHoistParams(
TF::ConcatV2Op op, int64_t axis,
const llvm::SmallDenseMap<Value, unsigned, 4> &exceptions) const {
assert(axis >= 0);
// Collects lhs or rhs arguments of concat op operands.
auto args = [&](int operand_idx) -> SmallVector<Value, 8> {
auto range = llvm::map_range(op.values(), [&](Value arg) {
if (exceptions.count(arg)) return Value();
return arg.getDefiningOp()->getOperand(operand_idx);
});
return {range.begin(), range.end()};
@ -1137,6 +1215,7 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams(TF::ConcatV2Op op,
// of `axis + 1` rank and axis dim has size `1`.
auto is_all_tensors = [&](int operand_idx, int axis) -> bool {
return llvm::all_of(op.values(), [&](Value arg) -> bool {
if (exceptions.count(arg)) return true;
auto operand = arg.getDefiningOp()->getOperand(operand_idx);
auto ranked = operand.getType().dyn_cast<RankedTensorType>();
return ranked && ranked.getRank() == (axis + 1) &&
@ -1147,6 +1226,7 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams(TF::ConcatV2Op op,
// Returns true if all binary ops operands at `operand_idx` index are scalars.
auto is_all_scalars = [&](int operand_idx) -> bool {
return llvm::all_of(op.values(), [&](Value arg) -> bool {
if (exceptions.count(arg)) return true;
auto operand = arg.getDefiningOp()->getOperand(operand_idx);
auto ranked = operand.getType().dyn_cast<RankedTensorType>();
return ranked && ranked.hasRank() && ranked.getRank() == 0;
@ -1168,13 +1248,24 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams(TF::ConcatV2Op op,
if (is_all_tensors(0, axis) && is_all_scalars(1)) {
std::array<int64_t, 1> rhs_dims{static_cast<int64_t>(op.values().size())};
auto rhs_type = RankedTensorType::get(rhs_dims, ranked.getElementType());
return HoistParams{args(0), args(1), axis, 0, op.getType(), rhs_type};
return HoistParams{args(0),
args(1),
axis,
0,
op.getType(),
rhs_type,
/*scalar_operand_idx=*/1};
} else if (is_all_tensors(1, axis) && is_all_scalars(0)) {
std::array<int64_t, 1> lhs_dims{static_cast<int64_t>(op.values().size())};
auto lhs_type = RankedTensorType::get(lhs_dims, ranked.getElementType());
return HoistParams{args(0), args(1), 0, axis, lhs_type, op.getType()};
return HoistParams{args(0),
args(1),
0,
axis,
lhs_type,
op.getType(),
/*scalar_operand_idx=*/0};
}
return None;
}

View File

@ -226,6 +226,50 @@ func @testConcatCwiseBinaryNegativeAxis(%arg0: tensor<f32>,
return %3 : tensor<2xf32>
}
// Synthesize binary ops when 1 of the 3 concat inputs is a non-binary op.
// CHECK-LABEL: testConcatCwiseBinarySynthMulOp3Inputs
func @testConcatCwiseBinarySynthMulOp3Inputs(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x3xf32> {
// CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<[2.000000e+00, 3.000000e+00, 1.000000e+00]>
// CHECK: %[[CONCAT:.*]] = "tf.ConcatV2"(%arg0, %arg1, %arg2,
// CHECK: "tf.Mul"(%[[CONCAT]], %[[CONST]])
%axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
%mul0_const = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
%mul0 = "tf.Mul"(%arg0, %mul0_const) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
%mul1_const = "tf.Const"() { value = dense<3.0> : tensor<f32> } : () -> tensor<f32>
%mul1 = "tf.Mul"(%arg1, %mul1_const) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
%ret = "tf.ConcatV2"(%mul0, %mul1, %arg2, %axis) : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>, tensor<i32>) -> tensor<?x3xf32>
return %ret : tensor<?x3xf32>
}
// Similar to to the above, with tf.Sub as the binary op kind.
func @testConcatCwiseBinarySynthSubOp3Inputs(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x3xf32> {
// CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<[2.000000e+00, 3.000000e+00, 0.000000e+00]>
// CHECK: %[[CONCAT:.*]] = "tf.ConcatV2"(%arg0, %arg1, %arg2,
// CHECK: "tf.Sub"(%[[CONCAT]], %[[CONST]])
%axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
%mul0_const = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
%mul0 = "tf.Sub"(%arg0, %mul0_const) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
%mul1_const = "tf.Const"() { value = dense<3.0> : tensor<f32> } : () -> tensor<f32>
%mul1 = "tf.Sub"(%arg1, %mul1_const) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
%ret = "tf.ConcatV2"(%mul0, %mul1, %arg2, %axis) : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>, tensor<i32>) -> tensor<?x3xf32>
return %ret : tensor<?x3xf32>
}
// Do not synthesize binary ops when 1 of the 2 concat inputs is a non-binary op.
// CHECK-LABEL: testConcatCwiseBinarySynthMulOp2Inputs
func @testConcatCwiseBinarySynthMulOp2Inputs(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>) -> tensor<?x2xf32> {
// CHECK: %[[MUL:.*]] = "tf.Mul"(%arg0,
// CHECK: "tf.ConcatV2"(%[[MUL]], %arg1,
%axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
%mul0_const = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
%mul0 = "tf.Mul"(%arg0, %mul0_const) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
%ret = "tf.ConcatV2"(%mul0, %arg1, %axis) : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<i32>) -> tensor<?x2xf32>
return %ret : tensor<?x2xf32>
}
// CHECK-LABEL: testLogOfSoftmax
func @testLogOfSoftmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.Softmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>