Enhanced the HoistCwiseBinaryOutOfConcat rewrite pattern to synthesize input binary ops when needed.

Where the majority of the inputs to a ConcatV2 op are produced by the same binary op kind B (e.g. tf.Mul), we now synthesize an op of kind B for each of the remaining inputs, pairing it with the identity scalar const tensor (e.g. 1.0 for tf.Mul). This allows HoistCwiseBinaryOutOfConcat to apply to a wider range of workloads.

PiperOrigin-RevId: 356821582
Change-Id: I252ba9abc3248953a88be978c98c3a3d4b55b089
parent 7e60a5de30
commit 78d28301bc
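As an illustration of the new behavior, here is a schematic before/after for a 3-input concat whose third input is not produced by a tf.Mul. This mirrors the new test testConcatCwiseBinarySynthMulOp3Inputs added below; value names follow that test, and the exact shape/type of the folded constant is an assumption of this sketch rather than a claim about the emitted IR:

  // Before: %arg2 is a plain value, so the old pattern would not fire.
  %c2 = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
  %c3 = "tf.Const"() { value = dense<3.0> : tensor<f32> } : () -> tensor<f32>
  %mul0 = "tf.Mul"(%arg0, %c2) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
  %mul1 = "tf.Mul"(%arg1, %c3) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
  %ret = "tf.ConcatV2"(%mul0, %mul1, %arg2, %axis) : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>, tensor<i32>) -> tensor<?x3xf32>

  // After: %arg2 is treated as tf.Mul(%arg2, 1.0), so the tf.Mul hoists out of
  // the concat and the per-input scalars (2.0, 3.0, and the synthesized
  // identity 1.0) fold into a single constant.
  %cst = "tf.Const"() { value = dense<[2.0, 3.0, 1.0]> : tensor<3xf32> } : () -> tensor<3xf32>
  %cat = "tf.ConcatV2"(%arg0, %arg1, %arg2, %axis) : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>, tensor<i32>) -> tensor<?x3xf32>
  %ret = "tf.Mul"(%cat, %cst) : (tensor<?x3xf32>, tensor<3xf32>) -> tensor<?x3xf32>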
@@ -1039,6 +1039,14 @@ LogicalResult HoistCwiseUnaryOutOfConcat::matchAndRewrite(
 // %1 = tf.ConcatV2(%rhs0, %rhs1, ..., %rhs_n, %rhs_concat_axis)
 // %2 = tf.Mul(%0, %1)
 //
+// If a minor fraction of the Concat inputs are not of the same binary op kind
+// (tf.Mul in the above example), we will synthesize the binary ops for those
+// inputs. e.g. if we instead have %1 = %lhs_1, then we would synthesize a
+// tf.Mul op over it and a scalar const tensor 1.0. For now this only applies to
+// float32 tensors.
+// TODO(hongm): Implement this op synthesis optimization for other dtypes if
+// needed.
+//
 // Because coefficient-wise binary operations support implicit broadcasting, we
 // should be very careful with this optimization, and do not accidentally
 // produce incorrect concat operations.
@@ -1057,11 +1065,17 @@ class HoistCwiseBinaryOutOfConcat : public OpRewritePattern<TF::ConcatV2Op> {
     int64_t rhs_axis;
     Type lhs_concat_type;
     Type rhs_concat_type;
+    int scalar_operand_idx;  // can be 0 or 1 for the binary op's operands.
   };

   // Returns parameters of a binary op hoisting out of concatenation if all of
   // the operands are in one of the compatible configurations.
-  Optional<HoistParams> GetHoistParams(TF::ConcatV2Op op, int64_t axis) const;
+  // All inputs of `op` should be of the same binary op kind (e.g. tf.Mul),
+  // except for the ones in `exceptions`. In that case, we can synthesize that
+  // binary op kind for the values in `exceptions`.
+  Optional<HoistParams> GetHoistParams(
+      TF::ConcatV2Op op, int64_t axis,
+      const llvm::SmallDenseMap<Value, unsigned, 4> &exceptions) const;
 };

 LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
@@ -1078,24 +1092,86 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
   // on the channels dim for NCHW layout as axis=-2.
   if (axis < 0) return failure();

-  // All concat operands must be defined by ops.
+  // All concat operands must be defined by ops of the same kind (e.g. tf.Mul),
+  // or some other ops that we might convert to using the same op kind above
+  // (e.g. converting op A to tf.Mul(A, 1.0)).
+  // TODO(hongm): generalize the code here to support cases where the first arg
+  // has no defining op (e.g. might be a block arg).
   Operation *first_arg_op = op.values().front().getDefiningOp();
   if (first_arg_op == nullptr) return failure();

   // All concat operands must be produced by the coeff-wise binary operation.
   if (!first_arg_op->hasTrait<OpTrait::TF::CwiseBinary>()) return failure();

-  // All concat operands must be defined by the op of same kind.
-  bool args_same_op = llvm::all_of(op.values(), [&](Value arg) -> bool {
+  // All concat operands must be defined by the op of same kind, except for a
+  // minor portion which we track in `exceptions`.
+  // Map from the operands to operand indices.
+  llvm::SmallDenseMap<Value, unsigned, 4> exceptions;
+  unsigned operand_idx = 0;
+  for (Value arg : op.values()) {
     Operation *arg_op = arg.getDefiningOp();
-    return arg_op && arg_op->getName() == first_arg_op->getName();
-  });
-  if (!args_same_op) return failure();
+    if (arg_op && arg_op->getName() == first_arg_op->getName()) {
+      ++operand_idx;
+      continue;
+    }
+    exceptions[arg] = operand_idx++;
+  }
+  // Recall those inputs to the concat op that are not produced by a binary op
+  // of the `first_arg_op` kind (e.g. tf.Mul) are stored in `exceptions`. If
+  // there are too many exceptions, it might not be cost effective to apply the
+  // concat hoisting optimization here.
+  // Setting the threshold to be 50% as a simple cost model heuristic. e.g. If 1
+  // out of 2 concat inputs is an exception, we don't apply the hoist. If it's 1
+  // out of 3, we do.
+  const float exception_pct_threshold = 0.5;
+  if (static_cast<float>(op.values().size()) * exception_pct_threshold <=
+      exceptions.size())
+    return failure();

   // Compute binary operands hoist parameters.
-  auto hoist_params = GetHoistParams(op, axis);
+  auto hoist_params = GetHoistParams(op, axis, exceptions);
   if (!hoist_params.hasValue()) return failure();

+  // Process `exceptions`: For each value there, synthesize a binary op of the
+  // above kind, so that the concat hoisting optimization can still apply.
+  if (!exceptions.empty()) {
+    int identity_val;
+    if (isa<AddOp>(first_arg_op) || isa<SubOp>(first_arg_op))
+      identity_val = 0;
+    else if (isa<MulOp>(first_arg_op) || isa<DivOp>(first_arg_op) ||
+             isa<RealDivOp>(first_arg_op))
+      identity_val = 1;
+    else
+      return failure();
+    DenseElementsAttr const_attr;
+    auto scalar_tensor_type =
+        first_arg_op->getOperand(hoist_params->scalar_operand_idx)
+            .getType()
+            .dyn_cast<ShapedType>();
+    Type scalar_dtype = scalar_tensor_type.getElementType();
+    if (scalar_dtype.isa<FloatType>())
+      const_attr = DenseElementsAttr::get(scalar_tensor_type,
+                                          static_cast<float>(identity_val));
+    else
+      return failure();
+
+    // All checks have passed, and we now prepare for the rewrite.
+    auto identity_const = rewriter.create<TF::ConstOp>(loc, const_attr);
+    for (const auto &kv : exceptions) {
+      assert(!hoist_params->lhs_args[kv.second]);
+      assert(!hoist_params->rhs_args[kv.second]);
+
+      if (hoist_params->scalar_operand_idx == 1) {
+        hoist_params->lhs_args[kv.second] = kv.first;
+        hoist_params->rhs_args[kv.second] = identity_const;
+      } else {
+        assert(hoist_params->scalar_operand_idx == 0);
+        hoist_params->lhs_args[kv.second] = identity_const;
+        hoist_params->rhs_args[kv.second] = kv.first;
+      }
+    }
+  }
+
   // New lhs and rhs concatenation axis.
   auto axis_type = mlir::RankedTensorType::get({}, rewriter.getIntegerType(64));
   auto lhs_axis = rewriter.create<TF::ConstOp>(
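To make the 50% heuristic above concrete: with 2 concat inputs and 1 exception, 2 * 0.5 = 1.0 <= 1 holds, so matchAndRewrite bails out; with 3 inputs and 1 exception, 3 * 0.5 = 1.5 <= 1 is false, so the hoist proceeds. Because the comparison is <=, strictly fewer than half of the inputs may be exceptions. These two cases correspond to the testConcatCwiseBinarySynthMulOp2Inputs and testConcatCwiseBinarySynthMulOp3Inputs tests added below.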
@@ -1122,12 +1198,14 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
 }

 Optional<HoistCwiseBinaryOutOfConcat::HoistParams>
-HoistCwiseBinaryOutOfConcat::GetHoistParams(TF::ConcatV2Op op,
-                                            int64_t axis) const {
+HoistCwiseBinaryOutOfConcat::GetHoistParams(
+    TF::ConcatV2Op op, int64_t axis,
+    const llvm::SmallDenseMap<Value, unsigned, 4> &exceptions) const {
   assert(axis >= 0);
   // Collects lhs or rhs arguments of concat op operands.
   auto args = [&](int operand_idx) -> SmallVector<Value, 8> {
     auto range = llvm::map_range(op.values(), [&](Value arg) {
+      if (exceptions.count(arg)) return Value();
       return arg.getDefiningOp()->getOperand(operand_idx);
     });
     return {range.begin(), range.end()};
@@ -1137,6 +1215,7 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams(TF::ConcatV2Op op,
   // of `axis + 1` rank and axis dim has size `1`.
   auto is_all_tensors = [&](int operand_idx, int axis) -> bool {
     return llvm::all_of(op.values(), [&](Value arg) -> bool {
+      if (exceptions.count(arg)) return true;
       auto operand = arg.getDefiningOp()->getOperand(operand_idx);
       auto ranked = operand.getType().dyn_cast<RankedTensorType>();
       return ranked && ranked.getRank() == (axis + 1) &&
@@ -1147,6 +1226,7 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams(TF::ConcatV2Op op,
   // Returns true if all binary ops operands at `operand_idx` index are scalars.
   auto is_all_scalars = [&](int operand_idx) -> bool {
     return llvm::all_of(op.values(), [&](Value arg) -> bool {
+      if (exceptions.count(arg)) return true;
       auto operand = arg.getDefiningOp()->getOperand(operand_idx);
       auto ranked = operand.getType().dyn_cast<RankedTensorType>();
       return ranked && ranked.hasRank() && ranked.getRank() == 0;
@@ -1168,13 +1248,24 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams(TF::ConcatV2Op op,
   if (is_all_tensors(0, axis) && is_all_scalars(1)) {
     std::array<int64_t, 1> rhs_dims{static_cast<int64_t>(op.values().size())};
     auto rhs_type = RankedTensorType::get(rhs_dims, ranked.getElementType());
-    return HoistParams{args(0), args(1), axis, 0, op.getType(), rhs_type};
+    return HoistParams{args(0),
+                       args(1),
+                       axis,
+                       0,
+                       op.getType(),
+                       rhs_type,
+                       /*scalar_operand_idx=*/1};
   } else if (is_all_tensors(1, axis) && is_all_scalars(0)) {
     std::array<int64_t, 1> lhs_dims{static_cast<int64_t>(op.values().size())};
     auto lhs_type = RankedTensorType::get(lhs_dims, ranked.getElementType());
-    return HoistParams{args(0), args(1), 0, axis, lhs_type, op.getType()};
+    return HoistParams{args(0),
+                       args(1),
+                       0,
+                       axis,
+                       lhs_type,
+                       op.getType(),
+                       /*scalar_operand_idx=*/0};
   }

   return None;
 }
@@ -226,6 +226,50 @@ func @testConcatCwiseBinaryNegativeAxis(%arg0: tensor<f32>,
   return %3 : tensor<2xf32>
 }

+// Synthesize binary ops when 1 of the 3 concat inputs is a non-binary op.
+// CHECK-LABEL: testConcatCwiseBinarySynthMulOp3Inputs
+func @testConcatCwiseBinarySynthMulOp3Inputs(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x3xf32> {
+  // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<[2.000000e+00, 3.000000e+00, 1.000000e+00]>
+  // CHECK: %[[CONCAT:.*]] = "tf.ConcatV2"(%arg0, %arg1, %arg2,
+  // CHECK: "tf.Mul"(%[[CONCAT]], %[[CONST]])
+  %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
+  %mul0_const = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
+  %mul0 = "tf.Mul"(%arg0, %mul0_const) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
+  %mul1_const = "tf.Const"() { value = dense<3.0> : tensor<f32> } : () -> tensor<f32>
+  %mul1 = "tf.Mul"(%arg1, %mul1_const) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
+  %ret = "tf.ConcatV2"(%mul0, %mul1, %arg2, %axis) : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>, tensor<i32>) -> tensor<?x3xf32>
+
+  return %ret : tensor<?x3xf32>
+}
+
+// Similar to the above, with tf.Sub as the binary op kind.
+func @testConcatCwiseBinarySynthSubOp3Inputs(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x3xf32> {
+  // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<[2.000000e+00, 3.000000e+00, 0.000000e+00]>
+  // CHECK: %[[CONCAT:.*]] = "tf.ConcatV2"(%arg0, %arg1, %arg2,
+  // CHECK: "tf.Sub"(%[[CONCAT]], %[[CONST]])
+  %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
+  %mul0_const = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
+  %mul0 = "tf.Sub"(%arg0, %mul0_const) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
+  %mul1_const = "tf.Const"() { value = dense<3.0> : tensor<f32> } : () -> tensor<f32>
+  %mul1 = "tf.Sub"(%arg1, %mul1_const) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
+  %ret = "tf.ConcatV2"(%mul0, %mul1, %arg2, %axis) : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>, tensor<i32>) -> tensor<?x3xf32>
+
+  return %ret : tensor<?x3xf32>
+}
+
+// Do not synthesize binary ops when 1 of the 2 concat inputs is a non-binary op.
+// CHECK-LABEL: testConcatCwiseBinarySynthMulOp2Inputs
+func @testConcatCwiseBinarySynthMulOp2Inputs(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>) -> tensor<?x2xf32> {
+  // CHECK: %[[MUL:.*]] = "tf.Mul"(%arg0,
+  // CHECK: "tf.ConcatV2"(%[[MUL]], %arg1,
+  %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
+  %mul0_const = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
+  %mul0 = "tf.Mul"(%arg0, %mul0_const) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
+  %ret = "tf.ConcatV2"(%mul0, %arg1, %axis) : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<i32>) -> tensor<?x2xf32>
+
+  return %ret : tensor<?x2xf32>
+}
+
 // CHECK-LABEL: testLogOfSoftmax
 func @testLogOfSoftmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
   %0 = "tf.Softmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>