From 352a98d60de6f5ca223ec64958ce9c3047df3111 Mon Sep 17 00:00:00 2001 From: Jaesung Chung <jaesung@google.com> Date: Tue, 23 Feb 2021 20:45:19 -0800 Subject: [PATCH] Implement Flex fallback for dynamic TensorList use cases PiperOrigin-RevId: 359198873 Change-Id: Ia0e9bf54380d317d22b13d2394e784d0099d4be0 --- tensorflow/compiler/mlir/lite/BUILD | 1 + .../lite/python/tf_tfl_flatbuffer_helpers.cc | 4 +- .../compiler/mlir/lite/tf_tfl_passes.cc | 12 +++-- tensorflow/compiler/mlir/lite/tf_tfl_passes.h | 2 + .../transforms/lower_static_tensor_list.cc | 10 ++-- .../compiler/mlir/lite/transforms/passes.h | 3 +- tensorflow/lite/python/lite_v2_test.py | 47 +++++++++++++++++++ 7 files changed, 70 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 2cd581b3a22..20ff771405e 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -976,6 +976,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/core:core_cpu_base", "//tensorflow/lite/toco:model_flags_proto_cc", + "//tensorflow/lite/toco:toco_flags_proto_cc", "@llvm-project//llvm:Support", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index edde83c046d..90c25d92c74 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -311,8 +311,8 @@ Status ConvertMLIRToTFLiteFlatBuffer( mlir::OpPassManager::Nesting::Implicit); ::tensorflow::SetCrashReproducer(pm); - tensorflow::AddTFToTFLConversionPasses(model_flags, pass_config, &pm, - session); + tensorflow::AddTFToTFLConversionPasses(model_flags, toco_flags, pass_config, + &pm, session); // Convert back to outlined while format for export back to flatbuffer. if (pass_config.legalize_tf_while) { pm.addPass(mlir::TFL::CreateWhileOutlinePass()); diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 32ff102ca29..14e4b495855 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -63,6 +63,7 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs, } void AddTFToTFLConversionPasses(const toco::ModelFlags& model_flags, + const toco::TocoFlags& toco_flags, const mlir::TFL::PassConfig& pass_config, mlir::OpPassManager* pass_manager, llvm::Optional<tensorflow::Session*> session) { @@ -132,7 +133,9 @@ void AddTFToTFLConversionPasses(const toco::ModelFlags& model_flags, if (pass_config.lower_tensor_list_ops) { // TODO(haoliang): Add this pass by default. - pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass()); + pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass( + /*allow_tensorlist_pass_through=*/toco_flags.force_select_tf_ops() || + toco_flags.enable_select_tf_ops())); } // This pass does resource analysis of saved model global tensors and marks @@ -266,7 +269,9 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, mlir::OpPassManager* pass_manager, llvm::Optional<tensorflow::Session*> session) { const toco::ModelFlags model_flags; - AddTFToTFLConversionPasses(model_flags, pass_config, pass_manager, session); + const toco::TocoFlags toco_flags; + AddTFToTFLConversionPasses(model_flags, toco_flags, pass_config, pass_manager, + session); } } // namespace tensorflow @@ -294,7 +299,8 @@ void CreateTFLStandardPipeline(OpPassManager& pm, mlir::TF::CreateTFStandardPipeline(func_pm, standard_pipeline_options); // This is needed for control flow support with TF TensorList. - pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass()); + pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass( + /*allow_tensorlist_pass_through=*/false)); // Saved model pass to mark global tensors immutable. pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass()); diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.h b/tensorflow/compiler/mlir/lite/tf_tfl_passes.h index 8104508a99f..1afa29a4682 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.h +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/core/public/session.h" #include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/toco_flags.pb.h" namespace tensorflow { @@ -30,6 +31,7 @@ namespace tensorflow { // imported from saved model version one and utilized for capturing resource // variables. void AddTFToTFLConversionPasses(const toco::ModelFlags& model_flags, + const toco::TocoFlags& toco_flags, const mlir::TFL::PassConfig& pass_config, mlir::OpPassManager* pass_manager, llvm::Optional<tensorflow::Session*> session); diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index 9bbb174e138..2de096b047a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -79,6 +79,9 @@ struct LowerStaticTensorListPass : public PassWrapper<LowerStaticTensorListPass, OperationPass<ModuleOp>> { LowerStaticTensorListPass() = default; LowerStaticTensorListPass(const LowerStaticTensorListPass &) {} + explicit LowerStaticTensorListPass(bool allow_tensorlist_pass_through) { + this->allow_tensorlist_pass_through = allow_tensorlist_pass_through; + } void runOnOperation() override; @@ -1058,9 +1061,10 @@ void LowerStaticTensorListPass::runOnOperation() { /// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList /// pass. -std::unique_ptr<OperationPass<ModuleOp>> -TFL::CreateLowerStaticTensorListPass() { - return std::make_unique<LowerStaticTensorListPass>(); +std::unique_ptr<OperationPass<ModuleOp>> TFL::CreateLowerStaticTensorListPass( + bool allow_tensorlist_pass_through) { + return std::make_unique<LowerStaticTensorListPass>( + allow_tensorlist_pass_through); } static PassRegistration<LowerStaticTensorListPass> pass( diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index 580cdd1103e..3dce44cc9ad 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -46,7 +46,8 @@ std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass( // Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList // pass. -std::unique_ptr<OperationPass<ModuleOp>> CreateLowerStaticTensorListPass(); +std::unique_ptr<OperationPass<ModuleOp>> CreateLowerStaticTensorListPass( + bool allow_tensorlist_pass_through = false); // Creates an instance of the TensorFlow Lite dialect Quantize pass. std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass( diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py index 6478d649c76..1133b49fc28 100644 --- a/tensorflow/lite/python/lite_v2_test.py +++ b/tensorflow/lite/python/lite_v2_test.py @@ -2042,6 +2042,53 @@ class ResourceAndVariantTypes(lite_v2_test_util.ModelTest): actual_value = interpreter.get_tensor(output_details[0]['index']) self.assertEqual(9.0, actual_value) + @test_util.run_v2_only + def testTensorListWithDynamicSize(self): + + def create_v1_saved_model(): + saved_model_dir = os.path.join(self.get_temp_dir(), + 'simple_mutable_variable') + with tf.Graph().as_default(): + with tf.compat.v1.Session() as sess: + in_tensor = tf.compat.v1.placeholder( + shape=[1], dtype=tf.float32, name='input') + + ta = tf.TensorArray( + tf.float32, size=0, dynamic_size=True, clear_after_read=False) + ta = ta.write(0, 10.0) + ta = ta.write(1, 20.0) + ta = ta.write(2, 30.0) + + out_tensor = ta.read(0) + ta.read(2) + + inputs = {'x': in_tensor} + outputs = {'z': out_tensor} + saved_model.simple_save(sess, saved_model_dir, inputs, outputs) + return saved_model_dir + + saved_model_dir = create_v1_saved_model() + + converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir) + converter.target_spec.supported_ops = [ + tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS + ] + tflite_model = converter.convert() + self.assertIsNotNone(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + interpreter.allocate_tensors() + + input_data = np.array([1.0], dtype=np.float32) + interpreter.set_tensor(input_details[0]['index'], input_data) + + interpreter.invoke() + actual_value = interpreter.get_tensor(output_details[0]['index']) + self.assertEqual(40.0, actual_value) + if __name__ == '__main__': test.main()