[TF:MLIR] Fold constant cond in IfOp
PiperOrigin-RevId: 321499354
Change-Id: Ie644b7f25d1c05d43c43f90983c7f5ac00866c47
commit 9c20fbff6c
parent 59215389ce
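This adds a canonicalization pattern for tf.If: when cond is a constant scalar, the op is folded into a direct call to the taken branch, a tf.PartitionedCall if the op is stateless and a tf.StatefulPartitionedCall otherwise, with device and underscore-prefixed attributes carried over to the call. An illustrative before/after, adapted from the FileCheck test in this change (@add and @sub are placeholder branch functions):

  %cond = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
  %r = "tf.If"(%cond, %x, %y) {then_branch = @add, else_branch = @sub, is_stateless = true}
       : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>

becomes

  %r = "tf.PartitionedCall"(%x, %y) {f = @sub, config = "", config_proto = "", executor_type = ""}
       : (tensor<f32>, tensor<f32>) -> tensor<f32>
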
@@ -175,6 +175,11 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
     // Add a shape inference pass to optimize away the unnecessary casts.
     pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
   }
+
+  // Inline function calls that are left in the graph after folding functional
+  // control flow ops (IfOp, CaseOp).
+  pass_manager->addPass(mlir::createInlinerPass());
+
   pass_manager->addPass(
       mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
   pass_manager->addPass(mlir::TFL::CreateOptimizePass());

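The inliner is added here so that the calls produced by folding IfOp/CaseOp do not survive to legalization. A minimal sketch of its effect, assuming a branch function @sub whose body is a single tf.Sub:

  %r = "tf.PartitionedCall"(%x, %y) {f = @sub, config = "", config_proto = "", executor_type = ""}
       : (tensor<f32>, tensor<f32>) -> tensor<f32>

inlines to

  %r = "tf.Sub"(%x, %y) : (tensor<f32>, tensor<f32>) -> tensor<f32>
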
@@ -33,7 +33,7 @@ def : Pat<(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt), (replaceWithValue $in)>;
 // point constant.
 def : Pat<(TFL_DequantizeOp
            (TFL_QuantizeOp (ConstantOp F32ElementsAttr:$cst), $qt)),
-          (ConstantOp $cst)>;
+          (TFL_ConstOp $cst)>;

 // Quantize the value of a constant op if the quantization parameters have been
 // propagated to the output.

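The pattern above now rebuilds the constant as a TFL const instead of a std constant, keeping the folded result in the TFL dialect. A hedged sketch of the rewrite, with quantized types elided (tfl.pseudo_const is assumed here to be the assembly form of TFL_ConstOp):

  %cst = constant dense<1.0> : tensor<f32>
  %q = "tfl.quantize"(%cst) {qtype = ...} : (tensor<f32>) -> tensor<...>
  %dq = "tfl.dequantize"(%q) : (tensor<...>) -> tensor<f32>

folds to

  %dq = "tfl.pseudo_const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
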
@@ -229,6 +229,8 @@ else_branch: A function that takes 'inputs' and returns a list of
   let verifier = [{
     return Verify(*this);
   }];
+
+  let hasCanonicalizer = 1;
 }

 def TF_YieldOp : TF_Op<"Yield",

@@ -463,7 +463,7 @@ LogicalResult FoldConstantCaseOp::matchAndRewrite(
   auto call_op = rewriter.create<PartitionedCallOp>(
       op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func,
       /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
-  PropagateAttributes(op.getOperation(), call_op);
+  PropagateDeviceAndInternalAttrs(op.getOperation(), call_op);
   rewriter.replaceOp(op, call_op.getResults());
   return success();
 }

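For context, FoldConstantCaseOp is the analogous fold for tf.Case: a constant branch_index selects a single branch, which is then called directly. An illustrative example modeled on the foldCase test below (@a and @b are hypothetical branch functions):

  %idx = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  %r = "tf.Case"(%idx, %x) {branches = [@a, @b], is_stateless = true}
       : (tensor<i32>, tensor<f32>) -> tensor<f32>

becomes

  %r = "tf.PartitionedCall"(%x) {f = @b, config = "", config_proto = "", executor_type = ""}
       : (tensor<f32>) -> tensor<f32>
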
@@ -1615,6 +1615,56 @@ static LogicalResult Verify(IfOp op) {
   return success();
 }
+
+class FoldConstantIfOp : public OpRewritePattern<TF::IfOp> {
+ public:
+  explicit FoldConstantIfOp(MLIRContext *context)
+      : OpRewritePattern<TF::IfOp>(context) {}
+  LogicalResult matchAndRewrite(TF::IfOp op,
+                                PatternRewriter &rewriter) const override;
+
+ private:
+  template <typename T>
+  struct CallOpType {
+    using CallOp = T;
+  };
+};
+
+LogicalResult FoldConstantIfOp::matchAndRewrite(
+    TF::IfOp op, PatternRewriter &rewriter) const {
+  // Extract the constant cond value.
+  DenseIntElementsAttr cond_attr;
+  if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure();
+
+  // Cond value must be a scalar.
+  if (cond_attr.getNumElements() != 1) return failure();
+
+  // Select a branch function.
+  bool cond = cond_attr.getSplatValue<BoolAttr>().getValue();
+  FlatSymbolRefAttr func = cond ? op.then_branchAttr() : op.else_branchAttr();
+
+  // Replace IfOp with PartitionedCallOp or StatefulPartitionedCallOp.
+  auto rewrite = [&](auto op_type) {
+    auto empty = rewriter.getStringAttr("");
+    auto call_op = rewriter.create<typename decltype(op_type)::CallOp>(
+        op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func,
+        /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
+    PropagateDeviceAndInternalAttrs(op.getOperation(), call_op);
+    rewriter.replaceOp(op, call_op.getResults());
+  };
+
+  if (op.is_stateless())
+    rewrite(CallOpType<PartitionedCallOp>{});
+  else
+    rewrite(CallOpType<StatefulPartitionedCallOp>{});
+
+  return success();
+}
+
+void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                       MLIRContext *context) {
+  results.insert<FoldConstantIfOp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // IfRegionOp
 //===----------------------------------------------------------------------===//

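When is_stateless is false, the same fold emits a tf.StatefulPartitionedCall, and PropagateDeviceAndInternalAttrs copies the device attribute plus any underscore-prefixed attributes onto it. A sketch matching the test below:

  %cond = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
  %r = "tf.If"(%cond, %x, %y) {then_branch = @add, else_branch = @sub, _underscore_attr = "something", is_stateless = false}
       : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>

becomes

  %r = "tf.StatefulPartitionedCall"(%x, %y) {f = @add, _underscore_attr = "something", config = "", config_proto = "", executor_type = ""}
       : (tensor<f32>, tensor<f32>) -> tensor<f32>
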
@@ -21,7 +21,7 @@ limitations under the License.
 // Propagates underscore and device attributes from src to dst.
 // TODO(b/158769932): This should be a general feature instead post some policy
 // discussion.
-static void PropagateAttributes(Operation *src, Operation *dst) {
+static void PropagateDeviceAndInternalAttrs(Operation *src, Operation *dst) {
   auto device = mlir::Identifier::get("device", src->getContext());
   for (auto named_attr : src->getAttrs()) {
     if (*named_attr.first.begin() == '_' || named_attr.first == device)

@@ -755,6 +755,27 @@ func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>) {
   return %2, %3, %4 : tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>
 }
+
+// CHECK-LABEL: foldIf
+func @foldIf(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> (tensor<f32>) {
+  %0 = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
+  %1 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+
+  // CHECK: %0 = "tf.PartitionedCall"(%arg0, %arg1)
+  // CHECK-SAME: device = "noodle"
+  // CHECK-SAME: f = @sub
+  %2 = "tf.If"(%0, %arg0, %arg1) {then_branch = @add, else_branch = @sub, output_shapes = [#tf.shape<>], device = "noodle", is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: %1 = "tf.StatefulPartitionedCall"(%0, %arg1)
+  // CHECK-SAME: _underscore_attr = "something"
+  // CHECK-SAME: f = @add
+  %3 = "tf.If"(%1, %2, %arg1) {then_branch = @add, else_branch = @sub, output_shapes = [#tf.shape<>], _underscore_attr = "something", is_stateless = false} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
+  // CHECK: %2 = "tf.If"
+  %4 = "tf.If"(%arg2, %3, %arg1) {then_branch = @add, else_branch = @sub, is_stateless = false} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
+  // CHECK: return %2
+  return %4 : tensor<f32>
+}

 // CHECK-LABEL: foldCase
 func @foldCase(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
   %2 = constant dense<1> : tensor<i32>

@@ -33,9 +33,10 @@ from tensorflow.python.ops import control_flow_ops

 def Test():
   data = tf.constant([1, 2, 3, 4, 5, 6])
-  zero = tf.convert_to_tensor(0)
-  one = tf.convert_to_tensor(1)
-  less_op = tf.less(zero, one)
+  # Create placeholders to prevent constant folding.
+  x_op = tf.placeholder(dtype=tf.int32)
+  y_op = tf.placeholder(dtype=tf.int32)
+  less_op = tf.less(x_op, y_op)
   switch_op = control_flow_ops.switch(data, less_op)
   merge_op = control_flow_ops.merge(switch_op)[0]
   result = tf.transpose(merge_op)