From 6be6d3b7ea72b369a9691bbfd1d0874f1127a3a3 Mon Sep 17 00:00:00 2001 From: Thai Nguyen Date: Mon, 1 Jun 2020 18:23:56 -0700 Subject: [PATCH] Make it clear that gen_selected_ops support multiple models PiperOrigin-RevId: 314243469 Change-Id: Id19f7926d60d340222b258a3eb53010ddea4dd89 --- tensorflow/lite/build_def.bzl | 11 +++++++--- .../lite/tools/gen_op_registration_main.cc | 21 ++++++++++++------- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index fd51ad0a4aa..285824a613f 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -634,7 +634,7 @@ def gen_selected_ops(name, model, namespace = "", **kwargs): Args: name: Name of the generated library. - model: TFLite model to interpret. + model: TFLite models to interpret, expect a list in case of multiple models. namespace: Namespace in which to put RegisterSelectedOps. **kwargs: Additional kwargs to pass to genrule. """ @@ -645,12 +645,17 @@ def gen_selected_ops(name, model, namespace = "", **kwargs): # isinstance is not supported in skylark. if type(model) != type([]): model = [model] + + input_models_args = " --input_models=%s" % ",".join( + ["$(location %s)" % f for f in model], + ) + native.genrule( name = name, srcs = model, outs = [out], - cmd = ("$(location %s) --namespace=%s --output_registration=$(location %s) --tflite_path=%s $(SRCS)") % - (tool, namespace, out, tflite_path[2:]), + cmd = ("$(location %s) --namespace=%s --output_registration=$(location %s) --tflite_path=%s %s") % + (tool, namespace, out, tflite_path[2:], input_models_args), tools = [tool], **kwargs ) diff --git a/tensorflow/lite/tools/gen_op_registration_main.cc b/tensorflow/lite/tools/gen_op_registration_main.cc index 410aaabf064..e4398663580 100644 --- a/tensorflow/lite/tools/gen_op_registration_main.cc +++ b/tensorflow/lite/tools/gen_op_registration_main.cc @@ -19,23 +19,24 @@ limitations under the License. #include #include +#include "absl/strings/str_split.h" #include "absl/strings/strip.h" #include "tensorflow/lite/tools/command_line_flags.h" #include "tensorflow/lite/tools/gen_op_registration.h" -const char kInputModelFlag[] = "input_model"; +const char kInputModelFlag[] = "input_models"; const char kNamespace[] = "namespace"; const char kOutputRegistrationFlag[] = "output_registration"; const char kTfLitePathFlag[] = "tflite_path"; const char kForMicro[] = "for_micro"; -void ParseFlagAndInit(int* argc, char** argv, std::string* input_model, +void ParseFlagAndInit(int* argc, char** argv, std::string* input_models, std::string* output_registration, std::string* tflite_path, std::string* namespace_flag, bool* for_micro) { std::vector flag_list = { - tflite::Flag::CreateFlag(kInputModelFlag, input_model, - "path to the tflite model"), + tflite::Flag::CreateFlag(kInputModelFlag, input_models, + "path to the tflite models, separated by comma"), tflite::Flag::CreateFlag(kOutputRegistrationFlag, output_registration, "filename for generated registration code"), tflite::Flag::CreateFlag(kTfLitePathFlag, tflite_path, @@ -144,22 +145,26 @@ void AddOpsFromModel(const std::string& input_model, } // namespace int main(int argc, char** argv) { - std::string input_model; + std::string input_models; std::string output_registration; std::string tflite_path; std::string namespace_flag; bool for_micro = false; - ParseFlagAndInit(&argc, argv, &input_model, &output_registration, + ParseFlagAndInit(&argc, argv, &input_models, &output_registration, &tflite_path, &namespace_flag, &for_micro); tflite::RegisteredOpMap builtin_ops; tflite::RegisteredOpMap custom_ops; - if (!input_model.empty()) { - AddOpsFromModel(input_model, &builtin_ops, &custom_ops); + if (!input_models.empty()) { + std::vector models = absl::StrSplit(input_models, ','); + for (const std::string& input_model : models) { + AddOpsFromModel(input_model, &builtin_ops, &custom_ops); + } } for (int i = 1; i < argc; i++) { AddOpsFromModel(argv[i], &builtin_ops, &custom_ops); } + GenerateFileContent(tflite_path, output_registration, namespace_flag, builtin_ops, custom_ops, for_micro); return 0;