TFL micro: Add allow_dynamic_tensor flag to TOCO.

PiperOrigin-RevId: 256914004
This commit is contained in:
Tiezhen WANG 2019-07-07 23:31:50 -07:00 committed by TensorFlower Gardener
parent fa4de2ca7a
commit dcb3eec82f
9 changed files with 118 additions and 3 deletions

View File

@ -500,3 +500,16 @@ tf_cc_test(
"@com_google_googletest//:gtest",
],
)
tf_cc_test(
name = "toco_cmdline_flags_test",
srcs = [
"toco_cmdline_flags_test.cc",
],
deps = [
":toco_cmdline_flags",
":toco_flags_proto_cc",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",
],
)

View File

@ -171,6 +171,7 @@ struct ParsedTocoFlags {
Arg<bool> drop_fake_quant = Arg<bool>(false);
Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
Arg<bool> allow_custom_ops = Arg<bool>(false);
Arg<bool> allow_dynamic_tensors = Arg<bool>(true);
Arg<bool> post_training_quantize = Arg<bool>(false);
Arg<bool> quantize_to_float16 = Arg<bool>(false);
// Deprecated flags

View File

@ -1,4 +1,4 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -605,6 +605,13 @@ tensorflow::Status Export(
/* name */ 0);
std::vector<flatbuffers::Offset<SubGraph>> subgraphs = {subgraph};
// TODO(wangtz): offline memory planning for activation Tensors.
if (!params.allow_dynamic_tensors) {
return tensorflow::errors::Unimplemented(
"Unsupported flag: allow_dynamic_tensors. Offline memory planning is "
"not implemented yet.");
}
auto buffers = ExportBuffers(model, buffers_to_write, &builder);
auto description = builder.CreateString("TOCO Converted.");
auto new_model_location =

View File

@ -28,6 +28,7 @@ enum class QuantizedBufferType { NONE, INT8, FLOAT16 };
// The parameters for exporting a TFLite model.
struct ExportParams {
bool allow_custom_ops = false;
bool allow_dynamic_tensors = true;
bool enable_select_tf_ops = false;
QuantizedBufferType quantize_weights = QuantizedBufferType::NONE;
};

View File

@ -16,17 +16,19 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/toco/tflite/builtin_operator.h"
#include "tensorflow/lite/toco/tflite/operator.h"
#include "tensorflow/lite/toco/tflite/types.h"
#include "tensorflow/core/framework/node_def.pb.h"
namespace toco {
namespace tflite {
namespace {
using ::testing::ElementsAre;
using ::testing::HasSubstr;
class ExportTest : public ::testing::Test {
protected:
@ -146,6 +148,11 @@ class ExportTest : public ::testing::Test {
}
}
tensorflow::Status ExportAndReturnStatus(const ExportParams& params) {
string result;
return Export(input_model_, &result, params);
}
std::vector<string> ExportAndSummarizeOperators(const ExportParams& params) {
std::vector<string> names;
@ -213,6 +220,17 @@ TEST_F(ExportTest, LoadOperatorsMap) {
3, operators[details::OperatorKey(::tflite::BuiltinOperator_SUB, "", 1)]);
}
TEST_F(ExportTest, UnsupportedFunctionality) {
AddOperatorsByName({"Conv"});
ExportParams params;
params.allow_dynamic_tensors = false;
auto status = ExportAndReturnStatus(params);
EXPECT_EQ(status.code(), ::tensorflow::error::UNIMPLEMENTED);
EXPECT_THAT(status.error_message(),
HasSubstr("Unsupported flag: allow_dynamic_tensors."));
}
TEST_F(ExportTest, Export) {
AddOperatorsByName({"Conv", "Add", "MyCrazyOp", "Sub"});

View File

@ -124,6 +124,13 @@ bool ParseTocoFlagsFromCommandLineFlags(
parsed_flags.allow_custom_ops.default_value(),
"If true, allow TOCO to create TF Lite Custom operators for all the "
"unsupported TensorFlow ops."),
Flag("allow_dynamic_tensors", parsed_flags.allow_dynamic_tensors.bind(),
parsed_flags.allow_dynamic_tensors.default_value(),
"Boolean flag indicating whether the converter should allow models "
"with dynamic Tensor shape. When set to False, the converter will "
"generate runtime memory offsets for activation Tensors (with 128 "
"bits alignment) and error out on models with undetermined Tensor "
"shape. (Default: True)"),
Flag(
"drop_control_dependency",
parsed_flags.drop_control_dependency.bind(),

View File

@ -0,0 +1,60 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/toco/toco_cmdline_flags.h"
#include <string>
#include <gtest/gtest.h>
#include "tensorflow/lite/testing/util.h"
namespace toco {
namespace {
TEST(TocoCmdlineFlagsTest, DefaultValue) {
int argc = 1;
// Invariant in ANSI C, len(argv) == argc +1 also argv[argc] == nullptr
// TF flag parsing lib is relaying on this invariant.
const char* args[] = {"toco", nullptr};
string message;
ParsedTocoFlags result_flags;
EXPECT_TRUE(ParseTocoFlagsFromCommandLineFlags(
&argc, const_cast<char**>(args), &message, &result_flags));
EXPECT_EQ(result_flags.allow_dynamic_tensors.value(), true);
}
TEST(TocoCmdlineFlagsTest, ParseFlags) {
int argc = 2;
const char* args[] = {"toco", "--allow_dynamic_tensors=false", nullptr};
string message;
ParsedTocoFlags result_flags;
EXPECT_TRUE(ParseTocoFlagsFromCommandLineFlags(
&argc, const_cast<char**>(args), &message, &result_flags));
EXPECT_EQ(result_flags.allow_dynamic_tensors.value(), false);
}
} // namespace
} // namespace toco
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
::toco::port::InitGoogleWasDoneElsewhere();
return RUN_ALL_TESTS();
}

View File

@ -38,7 +38,7 @@ enum FileFormat {
// of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job.
//
// Next ID to use: 30.
// Next ID to use: 31.
message TocoFlags {
// Input file format
optional FileFormat input_format = 1;
@ -212,4 +212,10 @@ message TocoFlags {
// wish to implement kernels on reduced precision floats for performance
// gains.
optional bool quantize_to_float16 = 29 [default = false];
// Boolean flag indicating whether the converter should allow models with
// dynamic Tensor shape. When set to False, the converter will generate
// runtime memory offsets for activation Tensors (with 128 bits alignment)
// and error out on models with undetermined Tensor shape. (Default: True)
optional bool allow_dynamic_tensors = 30 [default = true];
}

View File

@ -452,6 +452,8 @@ tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model,
params.enable_select_tf_ops =
toco_flags.force_select_tf_ops() || toco_flags.enable_select_tf_ops();
params.allow_custom_ops = allow_custom_ops;
params.allow_dynamic_tensors = toco_flags.allow_dynamic_tensors();
if (toco_flags.post_training_quantize()) {
if (toco_flags.quantize_to_float16()) {
params.quantize_weights = tflite::QuantizedBufferType::FLOAT16;