From 352a98d60de6f5ca223ec64958ce9c3047df3111 Mon Sep 17 00:00:00 2001
From: Jaesung Chung <jaesung@google.com>
Date: Tue, 23 Feb 2021 20:45:19 -0800
Subject: [PATCH] Implement Flex fallback for dynamic TensorList use cases

PiperOrigin-RevId: 359198873
Change-Id: Ia0e9bf54380d317d22b13d2394e784d0099d4be0
---
 tensorflow/compiler/mlir/lite/BUILD           |  1 +
 .../lite/python/tf_tfl_flatbuffer_helpers.cc  |  4 +-
 .../compiler/mlir/lite/tf_tfl_passes.cc       | 12 +++--
 tensorflow/compiler/mlir/lite/tf_tfl_passes.h |  2 +
 .../transforms/lower_static_tensor_list.cc    | 10 ++--
 .../compiler/mlir/lite/transforms/passes.h    |  3 +-
 tensorflow/lite/python/lite_v2_test.py        | 47 +++++++++++++++++++
 7 files changed, 70 insertions(+), 9 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 2cd581b3a22..20ff771405e 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -976,6 +976,7 @@ cc_library(
         "//tensorflow/compiler/mlir/tensorflow:translate_lib",
         "//tensorflow/core:core_cpu_base",
         "//tensorflow/lite/toco:model_flags_proto_cc",
+        "//tensorflow/lite/toco:toco_flags_proto_cc",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
         "@llvm-project//mlir:IR",
diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
index edde83c046d..90c25d92c74 100644
--- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
+++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
@@ -311,8 +311,8 @@ Status ConvertMLIRToTFLiteFlatBuffer(
                        mlir::OpPassManager::Nesting::Implicit);
   ::tensorflow::SetCrashReproducer(pm);
 
-  tensorflow::AddTFToTFLConversionPasses(model_flags, pass_config, &pm,
-                                         session);
+  tensorflow::AddTFToTFLConversionPasses(model_flags, toco_flags, pass_config,
+                                         &pm, session);
   // Convert back to outlined while format for export back to flatbuffer.
   if (pass_config.legalize_tf_while) {
     pm.addPass(mlir::TFL::CreateWhileOutlinePass());
diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
index 32ff102ca29..14e4b495855 100644
--- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
+++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
@@ -63,6 +63,7 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
 }
 
 void AddTFToTFLConversionPasses(const toco::ModelFlags& model_flags,
+                                const toco::TocoFlags& toco_flags,
                                 const mlir::TFL::PassConfig& pass_config,
                                 mlir::OpPassManager* pass_manager,
                                 llvm::Optional<tensorflow::Session*> session) {
@@ -132,7 +133,9 @@ void AddTFToTFLConversionPasses(const toco::ModelFlags& model_flags,
 
   if (pass_config.lower_tensor_list_ops) {
     // TODO(haoliang): Add this pass by default.
-    pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
+    pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass(
+        /*allow_tensorlist_pass_through=*/toco_flags.force_select_tf_ops() ||
+        toco_flags.enable_select_tf_ops()));
   }
 
   // This pass does resource analysis of saved model global tensors and marks
@@ -266,7 +269,9 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
                                 mlir::OpPassManager* pass_manager,
                                 llvm::Optional<tensorflow::Session*> session) {
   const toco::ModelFlags model_flags;
-  AddTFToTFLConversionPasses(model_flags, pass_config, pass_manager, session);
+  const toco::TocoFlags toco_flags;
+  AddTFToTFLConversionPasses(model_flags, toco_flags, pass_config, pass_manager,
+                             session);
 }
 
 }  // namespace tensorflow
@@ -294,7 +299,8 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
   mlir::TF::CreateTFStandardPipeline(func_pm, standard_pipeline_options);
 
   // This is needed for control flow support with TF TensorList.
-  pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass());
+  pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass(
+      /*allow_tensorlist_pass_through=*/false));
 
   // Saved model pass to mark global tensors immutable.
   pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.h b/tensorflow/compiler/mlir/lite/tf_tfl_passes.h
index 8104508a99f..1afa29a4682 100644
--- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.h
+++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.h
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
 #include "tensorflow/core/public/session.h"
 #include "tensorflow/lite/toco/model_flags.pb.h"
+#include "tensorflow/lite/toco/toco_flags.pb.h"
 
 namespace tensorflow {
 
@@ -30,6 +31,7 @@ namespace tensorflow {
 // imported from saved model version one and utilized for capturing resource
 // variables.
 void AddTFToTFLConversionPasses(const toco::ModelFlags& model_flags,
+                                const toco::TocoFlags& toco_flags,
                                 const mlir::TFL::PassConfig& pass_config,
                                 mlir::OpPassManager* pass_manager,
                                 llvm::Optional<tensorflow::Session*> session);
diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
index 9bbb174e138..2de096b047a 100644
--- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
@@ -79,6 +79,9 @@ struct LowerStaticTensorListPass
     : public PassWrapper<LowerStaticTensorListPass, OperationPass<ModuleOp>> {
   LowerStaticTensorListPass() = default;
   LowerStaticTensorListPass(const LowerStaticTensorListPass &) {}
+  explicit LowerStaticTensorListPass(bool allow_tensorlist_pass_through) {
+    this->allow_tensorlist_pass_through = allow_tensorlist_pass_through;
+  }
 
   void runOnOperation() override;
 
@@ -1058,9 +1061,10 @@ void LowerStaticTensorListPass::runOnOperation() {
 
 /// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
 /// pass.
-std::unique_ptr<OperationPass<ModuleOp>>
-TFL::CreateLowerStaticTensorListPass() {
-  return std::make_unique<LowerStaticTensorListPass>();
+std::unique_ptr<OperationPass<ModuleOp>> TFL::CreateLowerStaticTensorListPass(
+    bool allow_tensorlist_pass_through) {
+  return std::make_unique<LowerStaticTensorListPass>(
+      allow_tensorlist_pass_through);
 }
 
 static PassRegistration<LowerStaticTensorListPass> pass(
diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h
index 580cdd1103e..3dce44cc9ad 100644
--- a/tensorflow/compiler/mlir/lite/transforms/passes.h
+++ b/tensorflow/compiler/mlir/lite/transforms/passes.h
@@ -46,7 +46,8 @@ std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
 
 // Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
 // pass.
-std::unique_ptr<OperationPass<ModuleOp>> CreateLowerStaticTensorListPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateLowerStaticTensorListPass(
+    bool allow_tensorlist_pass_through = false);
 
 // Creates an instance of the TensorFlow Lite dialect Quantize pass.
 std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass(
diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py
index 6478d649c76..1133b49fc28 100644
--- a/tensorflow/lite/python/lite_v2_test.py
+++ b/tensorflow/lite/python/lite_v2_test.py
@@ -2042,6 +2042,53 @@ class ResourceAndVariantTypes(lite_v2_test_util.ModelTest):
     actual_value = interpreter.get_tensor(output_details[0]['index'])
     self.assertEqual(9.0, actual_value)
 
+  @test_util.run_v2_only
+  def testTensorListWithDynamicSize(self):
+
+    def create_v1_saved_model():
+      saved_model_dir = os.path.join(self.get_temp_dir(),
+                                     'simple_mutable_variable')
+      with tf.Graph().as_default():
+        with tf.compat.v1.Session() as sess:
+          in_tensor = tf.compat.v1.placeholder(
+              shape=[1], dtype=tf.float32, name='input')
+
+          ta = tf.TensorArray(
+              tf.float32, size=0, dynamic_size=True, clear_after_read=False)
+          ta = ta.write(0, 10.0)
+          ta = ta.write(1, 20.0)
+          ta = ta.write(2, 30.0)
+
+          out_tensor = ta.read(0) + ta.read(2)
+
+          inputs = {'x': in_tensor}
+          outputs = {'z': out_tensor}
+          saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
+      return saved_model_dir
+
+    saved_model_dir = create_v1_saved_model()
+
+    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
+    converter.target_spec.supported_ops = [
+        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
+    ]
+    tflite_model = converter.convert()
+    self.assertIsNotNone(tflite_model)
+
+    # Check values from converted model.
+    interpreter = Interpreter(model_content=tflite_model)
+    input_details = interpreter.get_input_details()
+    output_details = interpreter.get_output_details()
+
+    interpreter.allocate_tensors()
+
+    input_data = np.array([1.0], dtype=np.float32)
+    interpreter.set_tensor(input_details[0]['index'], input_data)
+
+    interpreter.invoke()
+    actual_value = interpreter.get_tensor(output_details[0]['index'])
+    self.assertEqual(40.0, actual_value)
+
 
 if __name__ == '__main__':
   test.main()