Put the pass for converting TF control flow regions to functional before

optimizing global tensors

This transforms WhileRegion to a functional control flow since TFLite passes do not deal with WhileRegion as yet.

PiperOrigin-RevId: 339552976
Change-Id: I767f9308a686d98a8e5cd553019c1e7b5b602230
This commit is contained in:
Jaesung Chung 2020-10-28 15:42:06 -07:00 committed by TensorFlower Gardener
parent 8c2a9ffb64
commit f9818f12c3
3 changed files with 37 additions and 12 deletions

View File

@ -406,16 +406,16 @@ versions {
min_consumer: 12
}
# CHECK: func @StatefulIf_else
# CHECK: func @main
# CHECK-NEXT: constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]>
# CHECK-NEXT: constant dense<[5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00]>
# CHECK-NEXT: tfl.mul
# CHECK: func @StatefulIf_then
# CHECK-NEXT: constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]>
# CHECK-NEXT: return
# CHECK: func @StatelessIf_else
# CHECK: "tf.If"{{.+}}else_branch = @cond_false_10{{.+}}is_stateless = true{{.+}}then_branch = @cond_true_10
# CHECK: "tf.If"{{.+}}else_branch = @cond_false0{{.+}}is_stateless = false{{.+}}then_branch = @cond_true0
# CHECK: func @cond_false_10
# CHECK-NEXT: tfl.div
# CHECK: func @StatelessIf_then
# CHECK: func @cond_true_10
# CHECK-NEXT: tfl.sub
# CHECK: "tf.If"{{.+}}else_branch = @StatelessIf_else{{.+}}then_branch = @StatelessIf_then
# CHECK: "tf.If"{{.+}}else_branch = @StatefulIf_else{{.+}}then_branch = @StatefulIf_then
# CHECK: func @cond_false0
# CHECK-NEXT: tfl.mul
# CHECK: func @cond_true0
# CHECK-NEXT: tfl.add

View File

@ -121,6 +121,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
}
pass_manager->addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());
pass_manager->addPass(mlir::createInlinerPass());
pass_manager->addPass(mlir::createSymbolDCEPass());
@ -139,8 +141,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
}
pass_manager->addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());
// Legalize while early to allow further constant folding.
// TODO(jpienaar): This may not actually matter as we do canonicalization
// after the legalize below, for now it needs to be below the above passes

View File

@ -888,6 +888,31 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
self.assertTrue(tflite_model)
@test_util.run_v2_only
def testNonStatefulConvLSTM2D(self):
"""Test saved model with non stateful ConvLSTM2D keras layer."""
# Create keras model
model = tf.keras.Sequential([
tf.keras.layers.ConvLSTM2D(
32, (3, 3),
padding='same',
return_sequences=True,
stateful=False,
batch_input_shape=(1, 1, 10, 10, 1))
])
model.compile()
# Export the keras model to saved model.
saved_model_dir = os.path.join(self.get_temp_dir(), 'conv_lstm_2d')
model.save(saved_model_dir, save_format='tf', include_optimizer=False)
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
]
tflite_model = converter.convert()
self.assertTrue(tflite_model)
class FromKerasModelTest(lite_v2_test_util.ModelTest):