Make it clear that gen_selected_ops support multiple models

PiperOrigin-RevId: 314243469
Change-Id: Id19f7926d60d340222b258a3eb53010ddea4dd89
This commit is contained in:
Thai Nguyen 2020-06-01 18:23:56 -07:00 committed by TensorFlower Gardener
parent 5bca62e44a
commit 6be6d3b7ea
2 changed files with 21 additions and 11 deletions

View File

@ -634,7 +634,7 @@ def gen_selected_ops(name, model, namespace = "", **kwargs):
Args:
name: Name of the generated library.
model: TFLite model to interpret.
model: TFLite models to interpret, expect a list in case of multiple models.
namespace: Namespace in which to put RegisterSelectedOps.
**kwargs: Additional kwargs to pass to genrule.
"""
@ -645,12 +645,17 @@ def gen_selected_ops(name, model, namespace = "", **kwargs):
# isinstance is not supported in skylark.
if type(model) != type([]):
model = [model]
input_models_args = " --input_models=%s" % ",".join(
["$(location %s)" % f for f in model],
)
native.genrule(
name = name,
srcs = model,
outs = [out],
cmd = ("$(location %s) --namespace=%s --output_registration=$(location %s) --tflite_path=%s $(SRCS)") %
(tool, namespace, out, tflite_path[2:]),
cmd = ("$(location %s) --namespace=%s --output_registration=$(location %s) --tflite_path=%s %s") %
(tool, namespace, out, tflite_path[2:], input_models_args),
tools = [tool],
**kwargs
)

View File

@ -19,23 +19,24 @@ limitations under the License.
#include <string>
#include <vector>
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/gen_op_registration.h"
const char kInputModelFlag[] = "input_model";
const char kInputModelFlag[] = "input_models";
const char kNamespace[] = "namespace";
const char kOutputRegistrationFlag[] = "output_registration";
const char kTfLitePathFlag[] = "tflite_path";
const char kForMicro[] = "for_micro";
void ParseFlagAndInit(int* argc, char** argv, std::string* input_model,
void ParseFlagAndInit(int* argc, char** argv, std::string* input_models,
std::string* output_registration,
std::string* tflite_path, std::string* namespace_flag,
bool* for_micro) {
std::vector<tflite::Flag> flag_list = {
tflite::Flag::CreateFlag(kInputModelFlag, input_model,
"path to the tflite model"),
tflite::Flag::CreateFlag(kInputModelFlag, input_models,
"path to the tflite models, separated by comma"),
tflite::Flag::CreateFlag(kOutputRegistrationFlag, output_registration,
"filename for generated registration code"),
tflite::Flag::CreateFlag(kTfLitePathFlag, tflite_path,
@ -144,22 +145,26 @@ void AddOpsFromModel(const std::string& input_model,
} // namespace
int main(int argc, char** argv) {
std::string input_model;
std::string input_models;
std::string output_registration;
std::string tflite_path;
std::string namespace_flag;
bool for_micro = false;
ParseFlagAndInit(&argc, argv, &input_model, &output_registration,
ParseFlagAndInit(&argc, argv, &input_models, &output_registration,
&tflite_path, &namespace_flag, &for_micro);
tflite::RegisteredOpMap builtin_ops;
tflite::RegisteredOpMap custom_ops;
if (!input_model.empty()) {
AddOpsFromModel(input_model, &builtin_ops, &custom_ops);
if (!input_models.empty()) {
std::vector<std::string> models = absl::StrSplit(input_models, ',');
for (const std::string& input_model : models) {
AddOpsFromModel(input_model, &builtin_ops, &custom_ops);
}
}
for (int i = 1; i < argc; i++) {
AddOpsFromModel(argv[i], &builtin_ops, &custom_ops);
}
GenerateFileContent(tflite_path, output_registration, namespace_flag,
builtin_ops, custom_ops, for_micro);
return 0;