Implement Flex fallback for dynamic TensorList use cases

PiperOrigin-RevId: 359198873
Change-Id: Ia0e9bf54380d317d22b13d2394e784d0099d4be0
This commit is contained in:
Jaesung Chung 2021-02-23 20:45:19 -08:00 committed by TensorFlower Gardener
parent f083f1834c
commit 352a98d60d
7 changed files with 70 additions and 9 deletions

View File

@ -976,6 +976,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/core:core_cpu_base",
"//tensorflow/lite/toco:model_flags_proto_cc",
"//tensorflow/lite/toco:toco_flags_proto_cc",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR",

View File

@ -311,8 +311,8 @@ Status ConvertMLIRToTFLiteFlatBuffer(
mlir::OpPassManager::Nesting::Implicit);
::tensorflow::SetCrashReproducer(pm);
tensorflow::AddTFToTFLConversionPasses(model_flags, pass_config, &pm,
session);
tensorflow::AddTFToTFLConversionPasses(model_flags, toco_flags, pass_config,
&pm, session);
// Convert back to outlined while format for export back to flatbuffer.
if (pass_config.legalize_tf_while) {
pm.addPass(mlir::TFL::CreateWhileOutlinePass());

View File

@ -63,6 +63,7 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
}
void AddTFToTFLConversionPasses(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags,
const mlir::TFL::PassConfig& pass_config,
mlir::OpPassManager* pass_manager,
llvm::Optional<tensorflow::Session*> session) {
@ -132,7 +133,9 @@ void AddTFToTFLConversionPasses(const toco::ModelFlags& model_flags,
if (pass_config.lower_tensor_list_ops) {
// TODO(haoliang): Add this pass by default.
pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass(
/*allow_tensorlist_pass_through=*/toco_flags.force_select_tf_ops() ||
toco_flags.enable_select_tf_ops()));
}
// This pass does resource analysis of saved model global tensors and marks
@ -266,7 +269,9 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
mlir::OpPassManager* pass_manager,
llvm::Optional<tensorflow::Session*> session) {
const toco::ModelFlags model_flags;
AddTFToTFLConversionPasses(model_flags, pass_config, pass_manager, session);
const toco::TocoFlags toco_flags;
AddTFToTFLConversionPasses(model_flags, toco_flags, pass_config, pass_manager,
session);
}
} // namespace tensorflow
@ -294,7 +299,8 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
mlir::TF::CreateTFStandardPipeline(func_pm, standard_pipeline_options);
// This is needed for control flow support with TF TensorList.
pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass());
pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass(
/*allow_tensorlist_pass_through=*/false));
// Saved model pass to mark global tensors immutable.
pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
namespace tensorflow {
@ -30,6 +31,7 @@ namespace tensorflow {
// imported from saved model version one and utilized for capturing resource
// variables.
void AddTFToTFLConversionPasses(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags,
const mlir::TFL::PassConfig& pass_config,
mlir::OpPassManager* pass_manager,
llvm::Optional<tensorflow::Session*> session);

View File

@ -79,6 +79,9 @@ struct LowerStaticTensorListPass
: public PassWrapper<LowerStaticTensorListPass, OperationPass<ModuleOp>> {
LowerStaticTensorListPass() = default;
LowerStaticTensorListPass(const LowerStaticTensorListPass &) {}
explicit LowerStaticTensorListPass(bool allow_tensorlist_pass_through) {
this->allow_tensorlist_pass_through = allow_tensorlist_pass_through;
}
void runOnOperation() override;
@ -1058,9 +1061,10 @@ void LowerStaticTensorListPass::runOnOperation() {
/// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
/// pass.
std::unique_ptr<OperationPass<ModuleOp>>
TFL::CreateLowerStaticTensorListPass() {
return std::make_unique<LowerStaticTensorListPass>();
std::unique_ptr<OperationPass<ModuleOp>> TFL::CreateLowerStaticTensorListPass(
bool allow_tensorlist_pass_through) {
return std::make_unique<LowerStaticTensorListPass>(
allow_tensorlist_pass_through);
}
static PassRegistration<LowerStaticTensorListPass> pass(

View File

@ -46,7 +46,8 @@ std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
// pass.
std::unique_ptr<OperationPass<ModuleOp>> CreateLowerStaticTensorListPass();
std::unique_ptr<OperationPass<ModuleOp>> CreateLowerStaticTensorListPass(
bool allow_tensorlist_pass_through = false);
// Creates an instance of the TensorFlow Lite dialect Quantize pass.
std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass(

View File

@ -2042,6 +2042,53 @@ class ResourceAndVariantTypes(lite_v2_test_util.ModelTest):
actual_value = interpreter.get_tensor(output_details[0]['index'])
self.assertEqual(9.0, actual_value)
@test_util.run_v2_only
def testTensorListWithDynamicSize(self):
  """Converts a model with a dynamically-sized TensorArray and checks output.

  A TensorArray created with dynamic_size=True cannot be lowered by the
  static tensor-list pass, so the converter relies on the Flex (SELECT_TF_OPS)
  fallback introduced for dynamic TensorList use cases — TFLITE_BUILTINS alone
  is presumably expected to fail here (TODO: confirm against the converter's
  error path).
  """

  def create_v1_saved_model():
    # Builds and exports a TF1 graph that writes three constants into a
    # dynamically-sized TensorArray and returns ta.read(0) + ta.read(2).
    # NOTE: dir name fixed from the copy-pasted 'simple_mutable_variable',
    # which belonged to an unrelated mutable-variable test.
    saved_model_dir = os.path.join(self.get_temp_dir(),
                                   'dynamic_size_tensor_list')
    with tf.Graph().as_default():
      with tf.compat.v1.Session() as sess:
        in_tensor = tf.compat.v1.placeholder(
            shape=[1], dtype=tf.float32, name='input')
        ta = tf.TensorArray(
            tf.float32, size=0, dynamic_size=True, clear_after_read=False)
        ta = ta.write(0, 10.0)
        ta = ta.write(1, 20.0)
        ta = ta.write(2, 30.0)
        out_tensor = ta.read(0) + ta.read(2)
        inputs = {'x': in_tensor}
        outputs = {'z': out_tensor}
        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
    return saved_model_dir

  saved_model_dir = create_v1_saved_model()
  converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
  # SELECT_TF_OPS enables the Flex fallback for the tensor-list ops that
  # cannot be converted to builtin TFLite ops.
  converter.target_spec.supported_ops = [
      tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
  ]
  tflite_model = converter.convert()
  self.assertIsNotNone(tflite_model)

  # Check values from converted model: ta.read(0) + ta.read(2) = 10 + 30.
  interpreter = Interpreter(model_content=tflite_model)
  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()
  interpreter.allocate_tensors()

  input_data = np.array([1.0], dtype=np.float32)
  interpreter.set_tensor(input_details[0]['index'], input_data)
  interpreter.invoke()

  actual_value = interpreter.get_tensor(output_details[0]['index'])
  self.assertEqual(40.0, actual_value)
# Standard test entry point: runs all test cases in this module.
if __name__ == '__main__':
  test.main()