Make it clear that gen_selected_ops support multiple models
PiperOrigin-RevId: 314243469 Change-Id: Id19f7926d60d340222b258a3eb53010ddea4dd89
This commit is contained in:
parent
5bca62e44a
commit
6be6d3b7ea
@ -634,7 +634,7 @@ def gen_selected_ops(name, model, namespace = "", **kwargs):
|
||||
|
||||
Args:
|
||||
name: Name of the generated library.
|
||||
model: TFLite model to interpret.
|
||||
model: TFLite models to interpret, expect a list in case of multiple models.
|
||||
namespace: Namespace in which to put RegisterSelectedOps.
|
||||
**kwargs: Additional kwargs to pass to genrule.
|
||||
"""
|
||||
@ -645,12 +645,17 @@ def gen_selected_ops(name, model, namespace = "", **kwargs):
|
||||
# isinstance is not supported in skylark.
|
||||
if type(model) != type([]):
|
||||
model = [model]
|
||||
|
||||
input_models_args = " --input_models=%s" % ",".join(
|
||||
["$(location %s)" % f for f in model],
|
||||
)
|
||||
|
||||
native.genrule(
|
||||
name = name,
|
||||
srcs = model,
|
||||
outs = [out],
|
||||
cmd = ("$(location %s) --namespace=%s --output_registration=$(location %s) --tflite_path=%s $(SRCS)") %
|
||||
(tool, namespace, out, tflite_path[2:]),
|
||||
cmd = ("$(location %s) --namespace=%s --output_registration=$(location %s) --tflite_path=%s %s") %
|
||||
(tool, namespace, out, tflite_path[2:], input_models_args),
|
||||
tools = [tool],
|
||||
**kwargs
|
||||
)
|
||||
|
@ -19,23 +19,24 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "absl/strings/strip.h"
|
||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||
#include "tensorflow/lite/tools/gen_op_registration.h"
|
||||
|
||||
const char kInputModelFlag[] = "input_model";
|
||||
const char kInputModelFlag[] = "input_models";
|
||||
const char kNamespace[] = "namespace";
|
||||
const char kOutputRegistrationFlag[] = "output_registration";
|
||||
const char kTfLitePathFlag[] = "tflite_path";
|
||||
const char kForMicro[] = "for_micro";
|
||||
|
||||
void ParseFlagAndInit(int* argc, char** argv, std::string* input_model,
|
||||
void ParseFlagAndInit(int* argc, char** argv, std::string* input_models,
|
||||
std::string* output_registration,
|
||||
std::string* tflite_path, std::string* namespace_flag,
|
||||
bool* for_micro) {
|
||||
std::vector<tflite::Flag> flag_list = {
|
||||
tflite::Flag::CreateFlag(kInputModelFlag, input_model,
|
||||
"path to the tflite model"),
|
||||
tflite::Flag::CreateFlag(kInputModelFlag, input_models,
|
||||
"path to the tflite models, separated by comma"),
|
||||
tflite::Flag::CreateFlag(kOutputRegistrationFlag, output_registration,
|
||||
"filename for generated registration code"),
|
||||
tflite::Flag::CreateFlag(kTfLitePathFlag, tflite_path,
|
||||
@ -144,22 +145,26 @@ void AddOpsFromModel(const std::string& input_model,
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
std::string input_model;
|
||||
std::string input_models;
|
||||
std::string output_registration;
|
||||
std::string tflite_path;
|
||||
std::string namespace_flag;
|
||||
bool for_micro = false;
|
||||
ParseFlagAndInit(&argc, argv, &input_model, &output_registration,
|
||||
ParseFlagAndInit(&argc, argv, &input_models, &output_registration,
|
||||
&tflite_path, &namespace_flag, &for_micro);
|
||||
|
||||
tflite::RegisteredOpMap builtin_ops;
|
||||
tflite::RegisteredOpMap custom_ops;
|
||||
if (!input_model.empty()) {
|
||||
AddOpsFromModel(input_model, &builtin_ops, &custom_ops);
|
||||
if (!input_models.empty()) {
|
||||
std::vector<std::string> models = absl::StrSplit(input_models, ',');
|
||||
for (const std::string& input_model : models) {
|
||||
AddOpsFromModel(input_model, &builtin_ops, &custom_ops);
|
||||
}
|
||||
}
|
||||
for (int i = 1; i < argc; i++) {
|
||||
AddOpsFromModel(argv[i], &builtin_ops, &custom_ops);
|
||||
}
|
||||
|
||||
GenerateFileContent(tflite_path, output_registration, namespace_flag,
|
||||
builtin_ops, custom_ops, for_micro);
|
||||
return 0;
|
||||
|
Loading…
x
Reference in New Issue
Block a user