* Switch the TFLite legalize pass to use a ConversionTarget that only allows valid TFLite ops to be converted.

* Switch from calling erase on the op directly to erasing through the rewriter (rewriter.eraseOp).
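
Taken together, the new pass body pairs a ConversionTarget, which declares per-op legality, with applyPartialConversion, which leaves ops that have no legality rule (the remaining TF ops) in place instead of failing on them. A condensed sketch of the idiom using the same names that appear in the diff below; treat it as an outline, not a drop-in file:

    #include "mlir/Transforms/DialectConversion.h"  // TF:llvm-project

    void Legalize(FuncOp func, MLIRContext* context,
                  OwningRewritePatternList& patterns) {
      ConversionTarget target(*context);
      // Plain constants are always legal.
      target.addLegalOp<mlir::ConstantOp>();
      // A TFLite op is legal only if its runtime verifier accepts it.
      target.addDynamicallyLegalDialect<TensorFlowLiteDialect>(
          Optional<ConversionTarget::DynamicLegalityCallbackFn>(
              [](Operation* op) {
                auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
                return tfl_op && succeeded(tfl_op.VerifyTflRuntimeTypes(
                                     tfl_op.getOperation()));
              }));
      // Partial conversion: ops with no legality rule (e.g. tf.*) are left
      // in place rather than reported as failures.
      if (failed(applyPartialConversion(func, target, patterns))) return;
    }

The second bullet follows from the first: the conversion driver tracks IR mutations through the rewriter, and a raw op->erase() bypasses that bookkeeping. Inside a pattern the erase therefore goes through the rewriter:

    PatternMatchResult matchAndRewrite(Operation* op,
                                       PatternRewriter& rewriter) const {
      rewriter.eraseOp(op);  // not op->erase(): the driver must see it
      return matchSuccess();
    }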

PiperOrigin-RevId: 300591035
Change-Id: I78c91b725e9d8da456d880eb82fcb83a93bdbf3a
Karim Nosir, 2020-03-12 11:42:27 -07:00, committed by TensorFlower Gardener
parent ebda90f26b
commit 32b8b3c30b
3 changed files with 45 additions and 18 deletions


@@ -40,6 +40,8 @@ class TensorFlowLiteDialect : public Dialect {
public:
explicit TensorFlowLiteDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "tfl"; }
+// Registered hook to materialize a constant operation from a given attribute
+// value with the desired resultant type.
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
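
For reference, such a hook typically builds the dialect's constant op when the attribute is representable and returns nullptr otherwise. A hedged sketch of that shape; the body is illustrative, not the implementation added by this commit:

    Operation* TensorFlowLiteDialect::materializeConstant(OpBuilder& builder,
                                                          Attribute value,
                                                          Type type,
                                                          Location loc) {
      // Illustrative: element attributes can be held by a tfl.const-style op.
      if (auto elements = value.dyn_cast<ElementsAttr>())
        return builder.create<ConstOp>(loc, type, elements);
      return nullptr;  // This dialect cannot materialize the value.
    }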


@@ -41,9 +41,9 @@ func @squeezeAndReshape(%arg0: tensor<1x1x10xf32>, %arg1: tensor<?x10xf32>) -> i
%4 = "some_op"(%1, %3) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
return %4 : i32
// CHECK-LABEL: squeezeAndReshape
-// CHECK: %cst = constant dense<[2, 5]> : tensor<2xi32>
// CHECK: "tfl.squeeze"(%arg0) {squeeze_dims = [0]} : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
// CHECK: %1 = "tfl.squeeze"(%arg1) {squeeze_dims = []} : (tensor<?x10xf32>) -> tensor<*xf32>
+// CHECK: %cst = constant dense<[2, 5]> : tensor<2xi32>
// CHECK: %2 = "tfl.reshape"(%0, %cst) : (tensor<1x10xf32>, tensor<2xi32>) -> tensor<2x5xf32>
// CHECK: %3 = "some_op"(%1, %2) : (tensor<*xf32>, tensor<2x5xf32>) -> i32
// CHECK: return
@@ -74,7 +74,7 @@ func @dynamicReshapeI64Fold(%arg0: tensor<*xf32>) -> tensor<1x2xf32> {
return %0 : tensor<1x2xf32>
// CHECK-LABEL: dynamicReshapeI64Fold
-// CHECK-NEXT: %[[cst:.*]] = constant dense<[1, 2]> : tensor<2xi32>
+// CHECK: %[[cst:.*]] = constant dense<[1, 2]> : tensor<2xi32>
// CHECK-NEXT: %[[reshape:.*]] = "tfl.reshape"(%arg0, %[[cst]]) : (tensor<*xf32>, tensor<2xi32>) -> tensor<1x2xf32>
// CHECK-NEXT: return %[[reshape]] : tensor<1x2xf32>
}
@@ -114,10 +114,10 @@ func @softplus(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
return %0 : tensor<8x16xf32>
// CHECK-LABEL: softplus
-// CHECK-NEXT: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<f32>
-// CHECK-NEXT: %[[exp:.*]] = "tfl.exp"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
-// CHECK-NEXT: %[[add:.*]] = "tfl.add"(%[[exp]], %[[cst]]) {fused_activation_function = "NONE"} : (tensor<8x16xf32>, tensor<f32>) -> tensor<8x16xf32>
-// CHECK-NEXT: %[[log:.*]] = "tfl.log"(%[[add]]) : (tensor<8x16xf32>) -> tensor<8x16xf32>
+// CHECK: %[[exp:.*]] = "tfl.exp"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
+// CHECK: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<f32>
+// CHECK: %[[add:.*]] = "tfl.add"(%[[exp]], %[[cst]]) {fused_activation_function = "NONE"} : (tensor<8x16xf32>, tensor<f32>) -> tensor<8x16xf32>
+// CHECK: %[[log:.*]] = "tfl.log"(%[[add]]) : (tensor<8x16xf32>) -> tensor<8x16xf32>
}
func @fakeQuantArgsFalse(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
@@ -676,8 +676,9 @@ func @matrix_diag_v2_no_match(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
// CHECK-SAME: [[VAL_0:%.*]]: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
// CHECK: [[VAL_1:%.*]] = constant dense<1> : tensor<1xi32>
// CHECK: [[VAL_2:%.*]] = constant dense<-1> : tensor<1xi32>
+// CHECK: [[VAL_5:%.*]] = constant dense<-1> : tensor<1xi32>
// CHECK: [[VAL_3:%.*]] = constant dense<0> : tensor<2xi32>
-// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV2"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32>
+// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV2"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_5]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32>
// CHECK: return [[VAL_4]] : tensor<8x16x16xf32>
}
@@ -709,8 +710,9 @@ func @matrix_diag_v3_no_match(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
// CHECK-SAME: [[VAL_0:%.*]]: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
// CHECK: [[VAL_1:%.*]] = constant dense<1> : tensor<1xi32>
// CHECK: [[VAL_2:%.*]] = constant dense<-1> : tensor<1xi32>
+// CHECK: [[VAL_5:%.*]] = constant dense<-1> : tensor<1xi32>
// CHECK: [[VAL_3:%.*]] = constant dense<0> : tensor<2xi32>
-// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV3"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32>
+// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV3"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_5]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32>
// CHECK: return [[VAL_4]] : tensor<8x16x16xf32>
}
@@ -1273,7 +1275,8 @@ func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf32>, %
// CHECK: %[[CST:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32>
// CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
// CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
-// CHECK: %[[ARG2:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
+// CHECK: %[[CST_1:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32>
+// CHECK: %[[ARG2:.*]] = "tfl.transpose"(%arg1, %[[CST_1]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
// CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG2]], %arg2) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32>
// CHECK: %[[RESULT:.*]] = tfl.add %[[ARG1]], %[[ARG3]] {fused_activation_function = "NONE"} : tensor<15x28x28x1xf32>
// CHECK: return %[[RESULT]] : tensor<15x28x28x1xf32>


@@ -38,6 +38,7 @@ limitations under the License.
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@@ -465,8 +466,7 @@ PatternMatchResult ConvertTFMatrixDiagV3Op::matchAndRewrite(
// TF Lite doesn't support Assert; we just drop the assert from the graph.
PatternMatchResult ConvertTFAssertOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
-op->dropAllReferences();
-op->erase();
+rewriter.eraseOp(op);
return matchSuccess();
}
@@ -653,7 +653,7 @@ struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
// Rewire the output.
op->getResult(2).replaceAllUsesWith(lstm_op.getResult());
-op->erase();
+rewriter.eraseOp(op);
return matchSuccess();
}
};
@@ -712,7 +712,7 @@ struct LegalizeUnidirectionalSequenceRnn : public RewritePattern {
// Rewire the output.
op->getResult(1).replaceAllUsesWith(rnn_op.getResult());
-op->erase();
+rewriter.eraseOp(op);
return matchSuccess();
}
@@ -720,22 +720,44 @@ struct LegalizeUnidirectionalSequenceRnn : public RewritePattern {
void LegalizeTF::runOnFunction() {
OwningRewritePatternList patterns;
-auto* ctx = &getContext();
+auto* context = &getContext();
auto func = getFunction();
// Add the generated patterns to the list.
-populateWithGenerated(ctx, &patterns);
+populateWithGenerated(context, &patterns);
patterns.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
ConvertTFMatrixDiagV2Op, ConvertTFMatrixDiagV3Op,
ConvertTFPackOp, ConvertTFReshapeOp, ConvertTFSplitOp,
ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp,
ConvertTFAssertOp, ConvertTFReciprocalOp,
-ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(ctx);
+ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context);
// Ophint python converter converted tf node pattern.
patterns.insert<LegalizeUnidirectionalSequenceLstm,
-LegalizeUnidirectionalSequenceRnn>(ctx);
-applyPatternsGreedily(func, patterns);
+LegalizeUnidirectionalSequenceRnn>(context);
+ConversionTarget target(*context);
+// It is legal to still have TF ops in the graph: they can be consumed
+// later, or, in the case of SELECT, we allow TF ops in the final graph.
+target.addLegalOp<mlir::ConstantOp>();
+target.addLegalOp<ConstOp>();
+target.addDynamicallyLegalDialect<TensorFlowLiteDialect>(
+    Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation* op) {
+      auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
+      if (!tfl_op) return false;
+      return succeeded(tfl_op.VerifyTflRuntimeTypes(tfl_op.getOperation()));
+    }));
+// Keep trying to convert.
+// TODO(karimnosseir): This is similar to what applyPatternsGreedily does.
+// Look for a function that retries until it converges.
+// Currently the unit test doesn't do multiple tries, so we need this.
+const int max_iterations = 15;
+for (int i = 0; i < max_iterations; ++i) {
+  if (failed(applyPartialConversion(func, target, patterns))) {
+    return;
+  }
+}
}
} // namespace
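
On the TODO above: one way to avoid a fixed 15-iteration budget is to stop as soon as an iteration makes no progress. A minimal sketch reusing the names from the pass; the counting lambda is an assumption for illustration, not code from this commit:

    // Count TFL ops that still fail their runtime verifier.
    auto count_illegal = [&]() {
      int n = 0;
      func.walk([&](Operation* op) {
        auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
        if (tfl_op &&
            failed(tfl_op.VerifyTflRuntimeTypes(tfl_op.getOperation())))
          ++n;
      });
      return n;
    };
    int prev = count_illegal();
    while (prev > 0) {
      if (failed(applyPartialConversion(func, target, patterns))) return;
      int now = count_illegal();
      if (now >= prev) break;  // converged (or stuck): stop retrying
      prev = now;
    }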