[TF:MLIR] Fold constant cond in IfOp
PiperOrigin-RevId: 321499354 Change-Id: Ie644b7f25d1c05d43c43f90983c7f5ac00866c47
This commit is contained in:
parent
59215389ce
commit
9c20fbff6c
@ -175,6 +175,11 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
|||||||
// Add a shape inference pass to optimize away the unnecessary casts.
|
// Add a shape inference pass to optimize away the unnecessary casts.
|
||||||
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
|
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Inline function calls that left in the graph after folding functional
|
||||||
|
// control flow ops (IfOp, CaseOp).
|
||||||
|
pass_manager->addPass(mlir::createInlinerPass());
|
||||||
|
|
||||||
pass_manager->addPass(
|
pass_manager->addPass(
|
||||||
mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
|
mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
|
||||||
pass_manager->addPass(mlir::TFL::CreateOptimizePass());
|
pass_manager->addPass(mlir::TFL::CreateOptimizePass());
|
||||||
|
@ -33,7 +33,7 @@ def : Pat<(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt), (replaceWithValue $in)>;
|
|||||||
// point constant.
|
// point constant.
|
||||||
def : Pat<(TFL_DequantizeOp
|
def : Pat<(TFL_DequantizeOp
|
||||||
(TFL_QuantizeOp (ConstantOp F32ElementsAttr:$cst), $qt)),
|
(TFL_QuantizeOp (ConstantOp F32ElementsAttr:$cst), $qt)),
|
||||||
(ConstantOp $cst)>;
|
(TFL_ConstOp $cst)>;
|
||||||
|
|
||||||
// Quantize the value of a constant op if the quantization parameters have been
|
// Quantize the value of a constant op if the quantization parameters have been
|
||||||
// propagated to the output.
|
// propagated to the output.
|
||||||
|
@ -229,6 +229,8 @@ else_branch: A function that takes 'inputs' and returns a list of
|
|||||||
let verifier = [{
|
let verifier = [{
|
||||||
return Verify(*this);
|
return Verify(*this);
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_YieldOp : TF_Op<"Yield",
|
def TF_YieldOp : TF_Op<"Yield",
|
||||||
|
@ -463,7 +463,7 @@ LogicalResult FoldConstantCaseOp::matchAndRewrite(
|
|||||||
auto call_op = rewriter.create<PartitionedCallOp>(
|
auto call_op = rewriter.create<PartitionedCallOp>(
|
||||||
op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func,
|
op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func,
|
||||||
/*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
|
/*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
|
||||||
PropagateAttributes(op.getOperation(), call_op);
|
PropagateDeviceAndInternalAttrs(op.getOperation(), call_op);
|
||||||
rewriter.replaceOp(op, call_op.getResults());
|
rewriter.replaceOp(op, call_op.getResults());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@ -1615,6 +1615,56 @@ static LogicalResult Verify(IfOp op) {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class FoldConstantIfOp : public OpRewritePattern<TF::IfOp> {
|
||||||
|
public:
|
||||||
|
explicit FoldConstantIfOp(MLIRContext *context)
|
||||||
|
: OpRewritePattern<TF::IfOp>(context) {}
|
||||||
|
LogicalResult matchAndRewrite(TF::IfOp op,
|
||||||
|
PatternRewriter &rewriter) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <typename T>
|
||||||
|
struct CallOpType {
|
||||||
|
using CallOp = T;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
LogicalResult FoldConstantIfOp::matchAndRewrite(
|
||||||
|
TF::IfOp op, PatternRewriter &rewriter) const {
|
||||||
|
// Extract the constant cond value.
|
||||||
|
DenseIntElementsAttr cond_attr;
|
||||||
|
if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure();
|
||||||
|
|
||||||
|
// Cond value must be a scalar.
|
||||||
|
if (cond_attr.getNumElements() != 1) return failure();
|
||||||
|
|
||||||
|
// Select a branch function.
|
||||||
|
bool cond = cond_attr.getSplatValue<BoolAttr>().getValue();
|
||||||
|
FlatSymbolRefAttr func = cond ? op.then_branchAttr() : op.else_branchAttr();
|
||||||
|
|
||||||
|
// Replace IfOp with PartitionedCallOp or StatefulPartitionedCallOp.
|
||||||
|
auto rewrite = [&](auto op_type) {
|
||||||
|
auto empty = rewriter.getStringAttr("");
|
||||||
|
auto call_op = rewriter.create<typename decltype(op_type)::CallOp>(
|
||||||
|
op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func,
|
||||||
|
/*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
|
||||||
|
PropagateDeviceAndInternalAttrs(op.getOperation(), call_op);
|
||||||
|
rewriter.replaceOp(op, call_op.getResults());
|
||||||
|
};
|
||||||
|
|
||||||
|
if (op.is_stateless())
|
||||||
|
rewrite(CallOpType<PartitionedCallOp>{});
|
||||||
|
else
|
||||||
|
rewrite(CallOpType<StatefulPartitionedCallOp>{});
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
|
MLIRContext *context) {
|
||||||
|
results.insert<FoldConstantIfOp>(context);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// IfRegionOp
|
// IfRegionOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
|||||||
// Propagates underscore and device attributes from src to dst.
|
// Propagates underscore and device attributes from src to dst.
|
||||||
// TODO(b/158769932): This should be a general feature instead post some policy
|
// TODO(b/158769932): This should be a general feature instead post some policy
|
||||||
// discussion.
|
// discussion.
|
||||||
static void PropagateAttributes(Operation *src, Operation *dst) {
|
static void PropagateDeviceAndInternalAttrs(Operation *src, Operation *dst) {
|
||||||
auto device = mlir::Identifier::get("device", src->getContext());
|
auto device = mlir::Identifier::get("device", src->getContext());
|
||||||
for (auto named_attr : src->getAttrs()) {
|
for (auto named_attr : src->getAttrs()) {
|
||||||
if (*named_attr.first.begin() == '_' || named_attr.first == device)
|
if (*named_attr.first.begin() == '_' || named_attr.first == device)
|
||||||
|
@ -755,6 +755,27 @@ func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>) {
|
|||||||
return %2, %3, %4 : tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>
|
return %2, %3, %4 : tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: foldIf
|
||||||
|
func @foldIf(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> (tensor<f32>) {
|
||||||
|
%0 = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
|
||||||
|
%1 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
|
||||||
|
|
||||||
|
// CHECK: %0 = "tf.PartitionedCall"(%arg0, %arg1)
|
||||||
|
// CHECK-SAME: device = "noodle"
|
||||||
|
// CHECK-SAME: f = @sub
|
||||||
|
%2 = "tf.If"(%0, %arg0, %arg1) {then_branch = @add, else_branch = @sub, output_shapes = [#tf.shape<>], device = "noodle", is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
|
// CHECK: %1 = "tf.StatefulPartitionedCall"(%0, %arg1)
|
||||||
|
// CHECK-SAME: _underscore_attr = "something"
|
||||||
|
// CHECK-SAME: f = @add
|
||||||
|
%3 = "tf.If"(%1, %2, %arg1) {then_branch = @add, else_branch = @sub, output_shapes = [#tf.shape<>], _underscore_attr = "something", is_stateless = false} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
|
|
||||||
|
// CHECK: %2 = "tf.If"
|
||||||
|
%4 = "tf.If"(%arg2, %3, %arg1) {then_branch = @add, else_branch = @sub, is_stateless = false} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
|
|
||||||
|
// CHECK: return %2
|
||||||
|
return %4 : tensor<f32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: foldCase
|
// CHECK-LABEL: foldCase
|
||||||
func @foldCase(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
|
func @foldCase(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
|
||||||
%2 = constant dense<1> : tensor<i32>
|
%2 = constant dense<1> : tensor<i32>
|
||||||
|
@ -33,9 +33,10 @@ from tensorflow.python.ops import control_flow_ops
|
|||||||
|
|
||||||
def Test():
|
def Test():
|
||||||
data = tf.constant([1, 2, 3, 4, 5, 6])
|
data = tf.constant([1, 2, 3, 4, 5, 6])
|
||||||
zero = tf.convert_to_tensor(0)
|
# Create placeholders to prevent constant folding.
|
||||||
one = tf.convert_to_tensor(1)
|
x_op = tf.placeholder(dtype=tf.int32)
|
||||||
less_op = tf.less(zero, one)
|
y_op = tf.placeholder(dtype=tf.int32)
|
||||||
|
less_op = tf.less(x_op, y_op)
|
||||||
switch_op = control_flow_ops.switch(data, less_op)
|
switch_op = control_flow_ops.switch(data, less_op)
|
||||||
merge_op = control_flow_ops.merge(switch_op)[0]
|
merge_op = control_flow_ops.merge(switch_op)[0]
|
||||||
result = tf.transpose(merge_op)
|
result = tf.transpose(merge_op)
|
||||||
|
Loading…
Reference in New Issue
Block a user