Implement Flex fallback for dynamic TensorList use cases
PiperOrigin-RevId: 359198873 Change-Id: Ia0e9bf54380d317d22b13d2394e784d0099d4be0
This commit is contained in:
parent
f083f1834c
commit
352a98d60d
@ -976,6 +976,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||
"//tensorflow/core:core_cpu_base",
|
||||
"//tensorflow/lite/toco:model_flags_proto_cc",
|
||||
"//tensorflow/lite/toco:toco_flags_proto_cc",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||
"@llvm-project//mlir:IR",
|
||||
|
@ -311,8 +311,8 @@ Status ConvertMLIRToTFLiteFlatBuffer(
|
||||
mlir::OpPassManager::Nesting::Implicit);
|
||||
::tensorflow::SetCrashReproducer(pm);
|
||||
|
||||
tensorflow::AddTFToTFLConversionPasses(model_flags, pass_config, &pm,
|
||||
session);
|
||||
tensorflow::AddTFToTFLConversionPasses(model_flags, toco_flags, pass_config,
|
||||
&pm, session);
|
||||
// Convert back to outlined while format for export back to flatbuffer.
|
||||
if (pass_config.legalize_tf_while) {
|
||||
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
|
||||
|
@ -63,6 +63,7 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
|
||||
}
|
||||
|
||||
void AddTFToTFLConversionPasses(const toco::ModelFlags& model_flags,
|
||||
const toco::TocoFlags& toco_flags,
|
||||
const mlir::TFL::PassConfig& pass_config,
|
||||
mlir::OpPassManager* pass_manager,
|
||||
llvm::Optional<tensorflow::Session*> session) {
|
||||
@ -132,7 +133,9 @@ void AddTFToTFLConversionPasses(const toco::ModelFlags& model_flags,
|
||||
|
||||
if (pass_config.lower_tensor_list_ops) {
|
||||
// TODO(haoliang): Add this pass by default.
|
||||
pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
|
||||
pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass(
|
||||
/*allow_tensorlist_pass_through=*/toco_flags.force_select_tf_ops() ||
|
||||
toco_flags.enable_select_tf_ops()));
|
||||
}
|
||||
|
||||
// This pass does resource analysis of saved model global tensors and marks
|
||||
@ -266,7 +269,9 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
mlir::OpPassManager* pass_manager,
|
||||
llvm::Optional<tensorflow::Session*> session) {
|
||||
const toco::ModelFlags model_flags;
|
||||
AddTFToTFLConversionPasses(model_flags, pass_config, pass_manager, session);
|
||||
const toco::TocoFlags toco_flags;
|
||||
AddTFToTFLConversionPasses(model_flags, toco_flags, pass_config, pass_manager,
|
||||
session);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
@ -294,7 +299,8 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
|
||||
mlir::TF::CreateTFStandardPipeline(func_pm, standard_pipeline_options);
|
||||
|
||||
// This is needed for control flow support with TF TensorList.
|
||||
pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass());
|
||||
pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass(
|
||||
/*allow_tensorlist_pass_through=*/false));
|
||||
|
||||
// Saved model pass to mark global tensors immutable.
|
||||
pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/lite/toco/model_flags.pb.h"
|
||||
#include "tensorflow/lite/toco/toco_flags.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -30,6 +31,7 @@ namespace tensorflow {
|
||||
// imported from saved model version one and utilized for capturing resource
|
||||
// variables.
|
||||
void AddTFToTFLConversionPasses(const toco::ModelFlags& model_flags,
|
||||
const toco::TocoFlags& toco_flags,
|
||||
const mlir::TFL::PassConfig& pass_config,
|
||||
mlir::OpPassManager* pass_manager,
|
||||
llvm::Optional<tensorflow::Session*> session);
|
||||
|
@ -79,6 +79,9 @@ struct LowerStaticTensorListPass
|
||||
: public PassWrapper<LowerStaticTensorListPass, OperationPass<ModuleOp>> {
|
||||
LowerStaticTensorListPass() = default;
|
||||
LowerStaticTensorListPass(const LowerStaticTensorListPass &) {}
|
||||
explicit LowerStaticTensorListPass(bool allow_tensorlist_pass_through) {
|
||||
this->allow_tensorlist_pass_through = allow_tensorlist_pass_through;
|
||||
}
|
||||
|
||||
void runOnOperation() override;
|
||||
|
||||
@ -1058,9 +1061,10 @@ void LowerStaticTensorListPass::runOnOperation() {
|
||||
|
||||
/// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
|
||||
/// pass.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
TFL::CreateLowerStaticTensorListPass() {
|
||||
return std::make_unique<LowerStaticTensorListPass>();
|
||||
std::unique_ptr<OperationPass<ModuleOp>> TFL::CreateLowerStaticTensorListPass(
|
||||
bool allow_tensorlist_pass_through) {
|
||||
return std::make_unique<LowerStaticTensorListPass>(
|
||||
allow_tensorlist_pass_through);
|
||||
}
|
||||
|
||||
static PassRegistration<LowerStaticTensorListPass> pass(
|
||||
|
@ -46,7 +46,8 @@ std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
|
||||
// pass.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateLowerStaticTensorListPass();
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateLowerStaticTensorListPass(
|
||||
bool allow_tensorlist_pass_through = false);
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect Quantize pass.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass(
|
||||
|
@ -2042,6 +2042,53 @@ class ResourceAndVariantTypes(lite_v2_test_util.ModelTest):
|
||||
actual_value = interpreter.get_tensor(output_details[0]['index'])
|
||||
self.assertEqual(9.0, actual_value)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testTensorListWithDynamicSize(self):
|
||||
|
||||
def create_v1_saved_model():
|
||||
saved_model_dir = os.path.join(self.get_temp_dir(),
|
||||
'simple_mutable_variable')
|
||||
with tf.Graph().as_default():
|
||||
with tf.compat.v1.Session() as sess:
|
||||
in_tensor = tf.compat.v1.placeholder(
|
||||
shape=[1], dtype=tf.float32, name='input')
|
||||
|
||||
ta = tf.TensorArray(
|
||||
tf.float32, size=0, dynamic_size=True, clear_after_read=False)
|
||||
ta = ta.write(0, 10.0)
|
||||
ta = ta.write(1, 20.0)
|
||||
ta = ta.write(2, 30.0)
|
||||
|
||||
out_tensor = ta.read(0) + ta.read(2)
|
||||
|
||||
inputs = {'x': in_tensor}
|
||||
outputs = {'z': out_tensor}
|
||||
saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
|
||||
return saved_model_dir
|
||||
|
||||
saved_model_dir = create_v1_saved_model()
|
||||
|
||||
converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
|
||||
converter.target_spec.supported_ops = [
|
||||
tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
|
||||
]
|
||||
tflite_model = converter.convert()
|
||||
self.assertIsNotNone(tflite_model)
|
||||
|
||||
# Check values from converted model.
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
input_details = interpreter.get_input_details()
|
||||
output_details = interpreter.get_output_details()
|
||||
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
input_data = np.array([1.0], dtype=np.float32)
|
||||
interpreter.set_tensor(input_details[0]['index'], input_data)
|
||||
|
||||
interpreter.invoke()
|
||||
actual_value = interpreter.get_tensor(output_details[0]['index'])
|
||||
self.assertEqual(40.0, actual_value)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user