Put the pass for converting TF control flow regions to functional before
optimizing global tensors This transforms WhileRegion to a functional control flow since TFLite passes do not deal with WhileRegion as yet. PiperOrigin-RevId: 339552976 Change-Id: I767f9308a686d98a8e5cd553019c1e7b5b602230
This commit is contained in:
parent
8c2a9ffb64
commit
f9818f12c3
@ -406,16 +406,16 @@ versions {
|
||||
min_consumer: 12
|
||||
}
|
||||
|
||||
# CHECK: func @StatefulIf_else
|
||||
# CHECK: func @main
|
||||
# CHECK-NEXT: constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]>
|
||||
# CHECK-NEXT: constant dense<[5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00]>
|
||||
# CHECK-NEXT: tfl.mul
|
||||
# CHECK: func @StatefulIf_then
|
||||
# CHECK-NEXT: constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]>
|
||||
# CHECK-NEXT: return
|
||||
# CHECK: func @StatelessIf_else
|
||||
# CHECK: "tf.If"{{.+}}else_branch = @cond_false_10{{.+}}is_stateless = true{{.+}}then_branch = @cond_true_10
|
||||
# CHECK: "tf.If"{{.+}}else_branch = @cond_false0{{.+}}is_stateless = false{{.+}}then_branch = @cond_true0
|
||||
# CHECK: func @cond_false_10
|
||||
# CHECK-NEXT: tfl.div
|
||||
# CHECK: func @StatelessIf_then
|
||||
# CHECK: func @cond_true_10
|
||||
# CHECK-NEXT: tfl.sub
|
||||
# CHECK: "tf.If"{{.+}}else_branch = @StatelessIf_else{{.+}}then_branch = @StatelessIf_then
|
||||
# CHECK: "tf.If"{{.+}}else_branch = @StatefulIf_else{{.+}}then_branch = @StatefulIf_then
|
||||
|
||||
# CHECK: func @cond_false0
|
||||
# CHECK-NEXT: tfl.mul
|
||||
# CHECK: func @cond_true0
|
||||
# CHECK-NEXT: tfl.add
|
||||
|
@ -121,6 +121,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
|
||||
}
|
||||
|
||||
pass_manager->addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());
|
||||
|
||||
pass_manager->addPass(mlir::createInlinerPass());
|
||||
pass_manager->addPass(mlir::createSymbolDCEPass());
|
||||
|
||||
@ -139,8 +141,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
|
||||
}
|
||||
|
||||
pass_manager->addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());
|
||||
|
||||
// Legalize while early to allow further constant folding.
|
||||
// TODO(jpienaar): This may not actually matter as we do canonicalization
|
||||
// after the legalize below, for now it needs to be below the above passes
|
||||
|
@ -888,6 +888,31 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
|
||||
|
||||
self.assertTrue(tflite_model)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testNonStatefulConvLSTM2D(self):
|
||||
"""Test saved model with non stateful ConvLSTM2D keras layer."""
|
||||
# Create keras model
|
||||
model = tf.keras.Sequential([
|
||||
tf.keras.layers.ConvLSTM2D(
|
||||
32, (3, 3),
|
||||
padding='same',
|
||||
return_sequences=True,
|
||||
stateful=False,
|
||||
batch_input_shape=(1, 1, 10, 10, 1))
|
||||
])
|
||||
model.compile()
|
||||
|
||||
# Export the keras model to saved model.
|
||||
saved_model_dir = os.path.join(self.get_temp_dir(), 'conv_lstm_2d')
|
||||
model.save(saved_model_dir, save_format='tf', include_optimizer=False)
|
||||
|
||||
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
|
||||
converter.target_spec.supported_ops = [
|
||||
tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
|
||||
]
|
||||
tflite_model = converter.convert()
|
||||
self.assertTrue(tflite_model)
|
||||
|
||||
|
||||
class FromKerasModelTest(lite_v2_test_util.ModelTest):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user