Enhanced the HoistCwiseBinaryOutOfConcat rewrite pattern to synthesize input binary ops when needed.

When a majority of the inputs to a ConcatV2 op are produced by the same binary op kind B (e.g. tf.Mul), we now synthesize a B op for each remaining input, pairing it with the identity scalar const tensor for B (e.g. 1.0 for tf.Mul). This allows HoistCwiseBinaryOutOfConcat to apply to a wider range of workloads.

PiperOrigin-RevId: 356821582
Change-Id: I252ba9abc3248953a88be978c98c3a3d4b55b089
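A sketch of the rewrite on the 3-input case from the new tests below, where %arg2 is the lone concat input not produced by a tf.Mul (value names are illustrative):

    // Before: two of the three ConcatV2 inputs are tf.Mul results.
    %mul0 = "tf.Mul"(%arg0, %two) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
    %mul1 = "tf.Mul"(%arg1, %three) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
    %ret = "tf.ConcatV2"(%mul0, %mul1, %arg2, %axis) : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>, tensor<i32>) -> tensor<?x3xf32>

    // After: %arg2 is treated as tf.Mul(%arg2, 1.0), so a single tf.Mul hoists
    // above one concat of the tensor operands and one vector of the scalars.
    %scalars = "tf.Const"() {value = dense<[2.0, 3.0, 1.0]> : tensor<3xf32>} : () -> tensor<3xf32>
    %concat = "tf.ConcatV2"(%arg0, %arg1, %arg2, %axis) : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>, tensor<i32>) -> tensor<?x3xf32>
    %ret = "tf.Mul"(%concat, %scalars) : (tensor<?x3xf32>, tensor<3xf32>) -> tensor<?x3xf32>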
commit 78d28301bc
parent 7e60a5de30
@@ -1039,6 +1039,14 @@ LogicalResult HoistCwiseUnaryOutOfConcat::matchAndRewrite(
 // %1 = tf.ConcatV2(%rhs0, %rhs1, ..., %rhs_n, %rhs_concat_axis)
 // %2 = tf.Mul(%0, %1)
 //
+// If a minor fraction of the Concat inputs are not of the same binary op kind
+// (tf.Mul in the above example), we will synthesize the binary ops for those
+// inputs, e.g. if we instead have %1 = %lhs_1, then we would synthesize a
+// tf.Mul op over it and a scalar const tensor 1.0. For now this only applies to
+// float32 tensors.
+// TODO(hongm): Implement this op synthesis optimization for other dtypes if
+// needed.
+//
 // Because coefficient-wise binary operations support implicit broadcasting, we
 // should be very careful with this optimization, and do not accidentally
 // produce incorrect concat operations.
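An illustrative counterexample for that caution (editorial, not from the patch): if the second operands were rank-1 tensors instead of scalars, hoisting across a concat on axis 0 would manufacture an invalid broadcast:

    // Each tf.Mul broadcasts a tensor<2xf32> across the rows of its lhs.
    %m0 = "tf.Mul"(%a0, %b0) : (tensor<1x2xf32>, tensor<2xf32>) -> tensor<1x2xf32>
    %m1 = "tf.Mul"(%a1, %b1) : (tensor<1x2xf32>, tensor<2xf32>) -> tensor<1x2xf32>
    %r = "tf.ConcatV2"(%m0, %m1, %axis0) : (tensor<1x2xf32>, tensor<1x2xf32>, tensor<i32>) -> tensor<2x2xf32>
    // Hoisting would need tf.Mul(concat(%a0, %a1), concat(%b0, %b1)), but the
    // rhs concat is a tensor<4xf32>, which does not broadcast against the
    // tensor<2x2xf32> lhs. Hence the strict operand configurations checked in
    // GetHoistParams below.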
@@ -1057,11 +1065,17 @@ class HoistCwiseBinaryOutOfConcat : public OpRewritePattern<TF::ConcatV2Op> {
     int64_t rhs_axis;
     Type lhs_concat_type;
     Type rhs_concat_type;
+    int scalar_operand_idx;  // can be 0 or 1 for the binary op's operands.
   };
 
   // Returns parameters of a binary op hoisting out of concatenation if all of
   // the operands are in one of the compatible configurations.
-  Optional<HoistParams> GetHoistParams(TF::ConcatV2Op op, int64_t axis) const;
+  // All inputs of `op` should be of the same binary op kind (e.g. tf.Mul),
+  // except for the ones in `exceptions`. In that case, we can synthesize that
+  // binary op kind for the values in `exceptions`.
+  Optional<HoistParams> GetHoistParams(
+      TF::ConcatV2Op op, int64_t axis,
+      const llvm::SmallDenseMap<Value, unsigned, 4> &exceptions) const;
 };
 
 LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
@@ -1078,24 +1092,86 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
   // on the channels dim for NCHW layout as axis=-2.
   if (axis < 0) return failure();
 
-  // All concat operands must be defined by ops.
+  // All concat operands must be defined by ops of the same kind (e.g. tf.Mul),
+  // or by ops that we can convert to that same op kind (e.g. converting op A
+  // to tf.Mul(A, 1.0)).
+  // TODO(hongm): generalize the code here to support cases where the first arg
+  // has no defining op (e.g. might be a block arg).
   Operation *first_arg_op = op.values().front().getDefiningOp();
   if (first_arg_op == nullptr) return failure();
 
   // All concat operands must be produced by the coeff-wise binary operation.
   if (!first_arg_op->hasTrait<OpTrait::TF::CwiseBinary>()) return failure();
 
-  // All concat operands must be defined by the op of same kind.
-  bool args_same_op = llvm::all_of(op.values(), [&](Value arg) -> bool {
+  // All concat operands must be defined by an op of the same kind, except for
+  // a minor portion which we track in `exceptions`, a map from those operands
+  // to their operand indices.
+  llvm::SmallDenseMap<Value, unsigned, 4> exceptions;
+  unsigned operand_idx = 0;
+  for (Value arg : op.values()) {
     Operation *arg_op = arg.getDefiningOp();
-    return arg_op && arg_op->getName() == first_arg_op->getName();
-  });
-  if (!args_same_op) return failure();
+    if (arg_op && arg_op->getName() == first_arg_op->getName()) {
+      ++operand_idx;
+      continue;
+    }
+    exceptions[arg] = operand_idx++;
+  }
+  // Recall that the concat inputs not produced by a binary op of the
+  // `first_arg_op` kind (e.g. tf.Mul) are stored in `exceptions`. If there
+  // are too many exceptions, it might not be cost effective to apply the
+  // concat hoisting optimization here.
+  // We set the threshold to 50% as a simple cost model heuristic: e.g. if 1
+  // out of 2 concat inputs is an exception, we don't apply the hoist; if it's
+  // 1 out of 3, we do.
+  const float exception_pct_threshold = 0.5;
+  if (static_cast<float>(op.values().size()) * exception_pct_threshold <=
+      exceptions.size())
+    return failure();
 
   // Compute binary operands hoist parameters.
-  auto hoist_params = GetHoistParams(op, axis);
+  auto hoist_params = GetHoistParams(op, axis, exceptions);
   if (!hoist_params.hasValue()) return failure();
 
+  // Process `exceptions`: for each value there, synthesize a binary op of the
+  // above kind, so that the concat hoisting optimization can still apply.
+  if (!exceptions.empty()) {
+    int identity_val;
+    if (isa<AddOp>(first_arg_op) || isa<SubOp>(first_arg_op))
+      identity_val = 0;
+    else if (isa<MulOp>(first_arg_op) || isa<DivOp>(first_arg_op) ||
+             isa<RealDivOp>(first_arg_op))
+      identity_val = 1;
+    else
+      return failure();
+    DenseElementsAttr const_attr;
+    auto scalar_tensor_type =
+        first_arg_op->getOperand(hoist_params->scalar_operand_idx)
+            .getType()
+            .dyn_cast<ShapedType>();
+    Type scalar_dtype = scalar_tensor_type.getElementType();
+    if (scalar_dtype.isa<FloatType>())
+      const_attr = DenseElementsAttr::get(scalar_tensor_type,
+                                          static_cast<float>(identity_val));
+    else
+      return failure();
+
+    // All checks have passed; we now prepare for the rewrite.
+    auto identity_const = rewriter.create<TF::ConstOp>(loc, const_attr);
+    for (const auto &kv : exceptions) {
+      assert(!hoist_params->lhs_args[kv.second]);
+      assert(!hoist_params->rhs_args[kv.second]);
+
+      // Position the exception value and the identity constant according to
+      // which operand of the binary op carries the scalar.
+      if (hoist_params->scalar_operand_idx == 1) {
+        hoist_params->lhs_args[kv.second] = kv.first;
+        hoist_params->rhs_args[kv.second] = identity_const;
+      } else {
+        assert(hoist_params->scalar_operand_idx == 0);
+        hoist_params->lhs_args[kv.second] = identity_const;
+        hoist_params->rhs_args[kv.second] = kv.first;
+      }
+    }
+  }
+
   // New lhs and rhs concatenation axis.
   auto axis_type = mlir::RankedTensorType::get({}, rewriter.getIntegerType(64));
   auto lhs_axis = rewriter.create<TF::ConstOp>(
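In effect the pattern pairs each exception input with the identity element of the binary op kind: 0 for tf.Add/tf.Sub, 1 for tf.Mul/tf.Div/tf.RealDiv; any other kind bails out. GetHoistParams (below) leaves null placeholders in lhs_args/rhs_args at each exception's index, and the loop above fills them with the exception value and the shared identity constant. Conceptually, an exception input %arg2 of a tf.Sub concat is handled as if it had been (names illustrative):

    %zero = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
    // Identity for tf.Sub: %arg2 - 0.0 == %arg2, so the input now matches the
    // binary op kind of its siblings, with the scalar at operand index 1.
    %synth = "tf.Sub"(%arg2, %zero) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>

After hoisting, the identity value shows up in the exception's slot of the concatenated scalar vector; see the dense<[2.0, 3.0, 0.0]> constant checked in the tf.Sub test below.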
@@ -1122,12 +1198,14 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
 }
 
 Optional<HoistCwiseBinaryOutOfConcat::HoistParams>
-HoistCwiseBinaryOutOfConcat::GetHoistParams(TF::ConcatV2Op op,
-                                            int64_t axis) const {
+HoistCwiseBinaryOutOfConcat::GetHoistParams(
+    TF::ConcatV2Op op, int64_t axis,
+    const llvm::SmallDenseMap<Value, unsigned, 4> &exceptions) const {
   assert(axis >= 0);
   // Collects lhs or rhs arguments of concat op operands.
   auto args = [&](int operand_idx) -> SmallVector<Value, 8> {
     auto range = llvm::map_range(op.values(), [&](Value arg) {
+      // Null placeholder for exceptions; filled in later by matchAndRewrite
+      // with the exception value and a synthesized identity constant.
+      if (exceptions.count(arg)) return Value();
       return arg.getDefiningOp()->getOperand(operand_idx);
     });
     return {range.begin(), range.end()};
@@ -1137,6 +1215,7 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams(TF::ConcatV2Op op,
   // of `axis + 1` rank and axis dim has size `1`.
   auto is_all_tensors = [&](int operand_idx, int axis) -> bool {
     return llvm::all_of(op.values(), [&](Value arg) -> bool {
+      if (exceptions.count(arg)) return true;
       auto operand = arg.getDefiningOp()->getOperand(operand_idx);
       auto ranked = operand.getType().dyn_cast<RankedTensorType>();
       return ranked && ranked.getRank() == (axis + 1) &&
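Concretely (types as in the new tests): for axis = 1, is_all_tensors(0, 1) accepts an lhs operand only if it is ranked, has rank 2, and has size 1 in dimension 1, so each per-input slice lines up one-to-one with a slot of the hoisted scalar vector:

    // %arg0 : tensor<?x1xf32> passes: rank 2 == axis + 1 and dim 1 has size 1.
    // A tensor<?x2xf32> or unranked operand would fail the check.
    %mul0 = "tf.Mul"(%arg0, %c) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>

Exception inputs return true vacuously; their shape is constrained by the ConcatV2 op itself rather than by a defining binary op.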
@@ -1147,6 +1226,7 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams(TF::ConcatV2Op op,
   // Returns true if all binary ops operands at `operand_idx` index are scalars.
   auto is_all_scalars = [&](int operand_idx) -> bool {
     return llvm::all_of(op.values(), [&](Value arg) -> bool {
+      if (exceptions.count(arg)) return true;
      auto operand = arg.getDefiningOp()->getOperand(operand_idx);
       auto ranked = operand.getType().dyn_cast<RankedTensorType>();
       return ranked && ranked.hasRank() && ranked.getRank() == 0;
@@ -1168,13 +1248,24 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams(TF::ConcatV2Op op,
   if (is_all_tensors(0, axis) && is_all_scalars(1)) {
     std::array<int64_t, 1> rhs_dims{static_cast<int64_t>(op.values().size())};
     auto rhs_type = RankedTensorType::get(rhs_dims, ranked.getElementType());
-    return HoistParams{args(0), args(1), axis, 0, op.getType(), rhs_type};
+    return HoistParams{args(0),
+                       args(1),
+                       axis,
+                       0,
+                       op.getType(),
+                       rhs_type,
+                       /*scalar_operand_idx=*/1};
   } else if (is_all_tensors(1, axis) && is_all_scalars(0)) {
     std::array<int64_t, 1> lhs_dims{static_cast<int64_t>(op.values().size())};
     auto lhs_type = RankedTensorType::get(lhs_dims, ranked.getElementType());
-    return HoistParams{args(0), args(1), 0, axis, lhs_type, op.getType()};
+    return HoistParams{args(0),
+                       args(1),
+                       0,
+                       axis,
+                       lhs_type,
+                       op.getType(),
+                       /*scalar_operand_idx=*/0};
   }
 
   return None;
 }
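The two branches are mirror images distinguished by scalar_operand_idx: either the scalars sit at operand index 1 (the hoisted op is B(concat-of-tensors, vector-of-scalars)) or at index 0 (the reverse). In both, the scalar side becomes a rank-1 tensor with one element per concat input. A sketch of the mirrored case, assuming ops like tf.Mul(%two, %arg0) fed the concat (names illustrative; the scalar concat is shown constant-folded):

    %scalars = "tf.Const"() {value = dense<[2.0, 3.0, 1.0]> : tensor<3xf32>} : () -> tensor<3xf32>
    // lhs_concat_type is tensor<3xf32>; the rank-1 lhs broadcasts across the
    // rows of the rank-2 rhs concat, preserving the per-input multiplications.
    %res = "tf.Mul"(%scalars, %concat) : (tensor<3xf32>, tensor<?x3xf32>) -> tensor<?x3xf32>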
@@ -226,6 +226,50 @@ func @testConcatCwiseBinaryNegativeAxis(%arg0: tensor<f32>,
   return %3 : tensor<2xf32>
 }
 
+// Synthesize binary ops when 1 of the 3 concat inputs is a non-binary op.
+// CHECK-LABEL: testConcatCwiseBinarySynthMulOp3Inputs
+func @testConcatCwiseBinarySynthMulOp3Inputs(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x3xf32> {
+  // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<[2.000000e+00, 3.000000e+00, 1.000000e+00]>
+  // CHECK: %[[CONCAT:.*]] = "tf.ConcatV2"(%arg0, %arg1, %arg2,
+  // CHECK: "tf.Mul"(%[[CONCAT]], %[[CONST]])
+  %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
+  %mul0_const = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
+  %mul0 = "tf.Mul"(%arg0, %mul0_const) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
+  %mul1_const = "tf.Const"() { value = dense<3.0> : tensor<f32> } : () -> tensor<f32>
+  %mul1 = "tf.Mul"(%arg1, %mul1_const) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
+  %ret = "tf.ConcatV2"(%mul0, %mul1, %arg2, %axis) : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>, tensor<i32>) -> tensor<?x3xf32>
+
+  return %ret : tensor<?x3xf32>
+}
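Note the operand ordering inside %[[CONST]]: exceptions keep their original concat operand index (exceptions[arg] = operand_idx in the pattern), so the synthesized identity 1.0 lands in the third slot, matching %arg2's position among the concat inputs.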
+
+// Similar to the above, with tf.Sub as the binary op kind.
+// CHECK-LABEL: testConcatCwiseBinarySynthSubOp3Inputs
+func @testConcatCwiseBinarySynthSubOp3Inputs(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x3xf32> {
+  // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<[2.000000e+00, 3.000000e+00, 0.000000e+00]>
+  // CHECK: %[[CONCAT:.*]] = "tf.ConcatV2"(%arg0, %arg1, %arg2,
+  // CHECK: "tf.Sub"(%[[CONCAT]], %[[CONST]])
+  %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
+  %mul0_const = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
+  %mul0 = "tf.Sub"(%arg0, %mul0_const) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
+  %mul1_const = "tf.Const"() { value = dense<3.0> : tensor<f32> } : () -> tensor<f32>
+  %mul1 = "tf.Sub"(%arg1, %mul1_const) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
+  %ret = "tf.ConcatV2"(%mul0, %mul1, %arg2, %axis) : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>, tensor<i32>) -> tensor<?x3xf32>
+
+  return %ret : tensor<?x3xf32>
+}
+
+// Do not synthesize binary ops when 1 of the 2 concat inputs is a non-binary op.
+// CHECK-LABEL: testConcatCwiseBinarySynthMulOp2Inputs
+func @testConcatCwiseBinarySynthMulOp2Inputs(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>) -> tensor<?x2xf32> {
+  // CHECK: %[[MUL:.*]] = "tf.Mul"(%arg0,
+  // CHECK: "tf.ConcatV2"(%[[MUL]], %arg1,
+  %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
+  %mul0_const = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
+  %mul0 = "tf.Mul"(%arg0, %mul0_const) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
+  %ret = "tf.ConcatV2"(%mul0, %arg1, %axis) : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<i32>) -> tensor<?x2xf32>
+
+  return %ret : tensor<?x2xf32>
+}
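This negative test exercises the 50% cost threshold in matchAndRewrite: with 2 concat inputs and 1 exception, 2 * 0.5 <= 1 holds, so the pattern returns failure() and the CHECK lines assert that the IR keeps its original tf.Mul-then-tf.ConcatV2 shape. One exception out of 3 inputs (1.5 > 1), as in the tests above, stays under the threshold and is rewritten.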
+
 // CHECK-LABEL: testLogOfSoftmax
 func @testLogOfSoftmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
   %0 = "tf.Softmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>