TFL micro: Add allow_dynamic_tensor flag to TOCO.
PiperOrigin-RevId: 256914004
This commit is contained in:
parent
fa4de2ca7a
commit
dcb3eec82f
@ -500,3 +500,16 @@ tf_cc_test(
|
|||||||
"@com_google_googletest//:gtest",
|
"@com_google_googletest//:gtest",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "toco_cmdline_flags_test",
|
||||||
|
srcs = [
|
||||||
|
"toco_cmdline_flags_test.cc",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":toco_cmdline_flags",
|
||||||
|
":toco_flags_proto_cc",
|
||||||
|
"//tensorflow/lite/testing:util",
|
||||||
|
"@com_google_googletest//:gtest",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -171,6 +171,7 @@ struct ParsedTocoFlags {
|
|||||||
Arg<bool> drop_fake_quant = Arg<bool>(false);
|
Arg<bool> drop_fake_quant = Arg<bool>(false);
|
||||||
Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
|
Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
|
||||||
Arg<bool> allow_custom_ops = Arg<bool>(false);
|
Arg<bool> allow_custom_ops = Arg<bool>(false);
|
||||||
|
Arg<bool> allow_dynamic_tensors = Arg<bool>(true);
|
||||||
Arg<bool> post_training_quantize = Arg<bool>(false);
|
Arg<bool> post_training_quantize = Arg<bool>(false);
|
||||||
Arg<bool> quantize_to_float16 = Arg<bool>(false);
|
Arg<bool> quantize_to_float16 = Arg<bool>(false);
|
||||||
// Deprecated flags
|
// Deprecated flags
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -605,6 +605,13 @@ tensorflow::Status Export(
|
|||||||
/* name */ 0);
|
/* name */ 0);
|
||||||
std::vector<flatbuffers::Offset<SubGraph>> subgraphs = {subgraph};
|
std::vector<flatbuffers::Offset<SubGraph>> subgraphs = {subgraph};
|
||||||
|
|
||||||
|
// TODO(wangtz): offline memory planning for activation Tensors.
|
||||||
|
if (!params.allow_dynamic_tensors) {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"Unsupported flag: allow_dynamic_tensors. Offline memory planning is "
|
||||||
|
"not implemented yet.");
|
||||||
|
}
|
||||||
|
|
||||||
auto buffers = ExportBuffers(model, buffers_to_write, &builder);
|
auto buffers = ExportBuffers(model, buffers_to_write, &builder);
|
||||||
auto description = builder.CreateString("TOCO Converted.");
|
auto description = builder.CreateString("TOCO Converted.");
|
||||||
auto new_model_location =
|
auto new_model_location =
|
||||||
|
@ -28,6 +28,7 @@ enum class QuantizedBufferType { NONE, INT8, FLOAT16 };
|
|||||||
// The parameters for exporting a TFLite model.
|
// The parameters for exporting a TFLite model.
|
||||||
struct ExportParams {
|
struct ExportParams {
|
||||||
bool allow_custom_ops = false;
|
bool allow_custom_ops = false;
|
||||||
|
bool allow_dynamic_tensors = true;
|
||||||
bool enable_select_tf_ops = false;
|
bool enable_select_tf_ops = false;
|
||||||
QuantizedBufferType quantize_weights = QuantizedBufferType::NONE;
|
QuantizedBufferType quantize_weights = QuantizedBufferType::NONE;
|
||||||
};
|
};
|
||||||
|
@ -16,17 +16,19 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <gmock/gmock.h>
|
#include <gmock/gmock.h>
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
#include "tensorflow/lite/toco/tflite/builtin_operator.h"
|
#include "tensorflow/lite/toco/tflite/builtin_operator.h"
|
||||||
#include "tensorflow/lite/toco/tflite/operator.h"
|
#include "tensorflow/lite/toco/tflite/operator.h"
|
||||||
#include "tensorflow/lite/toco/tflite/types.h"
|
#include "tensorflow/lite/toco/tflite/types.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
|
||||||
|
|
||||||
namespace toco {
|
namespace toco {
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::testing::ElementsAre;
|
using ::testing::ElementsAre;
|
||||||
|
using ::testing::HasSubstr;
|
||||||
|
|
||||||
class ExportTest : public ::testing::Test {
|
class ExportTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
@ -146,6 +148,11 @@ class ExportTest : public ::testing::Test {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tensorflow::Status ExportAndReturnStatus(const ExportParams& params) {
|
||||||
|
string result;
|
||||||
|
return Export(input_model_, &result, params);
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<string> ExportAndSummarizeOperators(const ExportParams& params) {
|
std::vector<string> ExportAndSummarizeOperators(const ExportParams& params) {
|
||||||
std::vector<string> names;
|
std::vector<string> names;
|
||||||
|
|
||||||
@ -213,6 +220,17 @@ TEST_F(ExportTest, LoadOperatorsMap) {
|
|||||||
3, operators[details::OperatorKey(::tflite::BuiltinOperator_SUB, "", 1)]);
|
3, operators[details::OperatorKey(::tflite::BuiltinOperator_SUB, "", 1)]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ExportTest, UnsupportedFunctionality) {
|
||||||
|
AddOperatorsByName({"Conv"});
|
||||||
|
|
||||||
|
ExportParams params;
|
||||||
|
params.allow_dynamic_tensors = false;
|
||||||
|
auto status = ExportAndReturnStatus(params);
|
||||||
|
EXPECT_EQ(status.code(), ::tensorflow::error::UNIMPLEMENTED);
|
||||||
|
EXPECT_THAT(status.error_message(),
|
||||||
|
HasSubstr("Unsupported flag: allow_dynamic_tensors."));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ExportTest, Export) {
|
TEST_F(ExportTest, Export) {
|
||||||
AddOperatorsByName({"Conv", "Add", "MyCrazyOp", "Sub"});
|
AddOperatorsByName({"Conv", "Add", "MyCrazyOp", "Sub"});
|
||||||
|
|
||||||
|
@ -124,6 +124,13 @@ bool ParseTocoFlagsFromCommandLineFlags(
|
|||||||
parsed_flags.allow_custom_ops.default_value(),
|
parsed_flags.allow_custom_ops.default_value(),
|
||||||
"If true, allow TOCO to create TF Lite Custom operators for all the "
|
"If true, allow TOCO to create TF Lite Custom operators for all the "
|
||||||
"unsupported TensorFlow ops."),
|
"unsupported TensorFlow ops."),
|
||||||
|
Flag("allow_dynamic_tensors", parsed_flags.allow_dynamic_tensors.bind(),
|
||||||
|
parsed_flags.allow_dynamic_tensors.default_value(),
|
||||||
|
"Boolean flag indicating whether the converter should allow models "
|
||||||
|
"with dynamic Tensor shape. When set to False, the converter will "
|
||||||
|
"generate runtime memory offsets for activation Tensors (with 128 "
|
||||||
|
"bits alignment) and error out on models with undetermined Tensor "
|
||||||
|
"shape. (Default: True)"),
|
||||||
Flag(
|
Flag(
|
||||||
"drop_control_dependency",
|
"drop_control_dependency",
|
||||||
parsed_flags.drop_control_dependency.bind(),
|
parsed_flags.drop_control_dependency.bind(),
|
||||||
|
60
tensorflow/lite/toco/toco_cmdline_flags_test.cc
Normal file
60
tensorflow/lite/toco/toco_cmdline_flags_test.cc
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/toco/toco_cmdline_flags.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "tensorflow/lite/testing/util.h"
|
||||||
|
|
||||||
|
namespace toco {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
TEST(TocoCmdlineFlagsTest, DefaultValue) {
|
||||||
|
int argc = 1;
|
||||||
|
// Invariant in ANSI C, len(argv) == argc +1 also argv[argc] == nullptr
|
||||||
|
// TF flag parsing lib is relaying on this invariant.
|
||||||
|
const char* args[] = {"toco", nullptr};
|
||||||
|
|
||||||
|
string message;
|
||||||
|
ParsedTocoFlags result_flags;
|
||||||
|
|
||||||
|
EXPECT_TRUE(ParseTocoFlagsFromCommandLineFlags(
|
||||||
|
&argc, const_cast<char**>(args), &message, &result_flags));
|
||||||
|
EXPECT_EQ(result_flags.allow_dynamic_tensors.value(), true);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TocoCmdlineFlagsTest, ParseFlags) {
|
||||||
|
int argc = 2;
|
||||||
|
const char* args[] = {"toco", "--allow_dynamic_tensors=false", nullptr};
|
||||||
|
|
||||||
|
string message;
|
||||||
|
ParsedTocoFlags result_flags;
|
||||||
|
|
||||||
|
EXPECT_TRUE(ParseTocoFlagsFromCommandLineFlags(
|
||||||
|
&argc, const_cast<char**>(args), &message, &result_flags));
|
||||||
|
EXPECT_EQ(result_flags.allow_dynamic_tensors.value(), false);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace toco
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
::tflite::LogToStderr();
|
||||||
|
::testing::InitGoogleTest(&argc, argv);
|
||||||
|
::toco::port::InitGoogleWasDoneElsewhere();
|
||||||
|
return RUN_ALL_TESTS();
|
||||||
|
}
|
@ -38,7 +38,7 @@ enum FileFormat {
|
|||||||
// of as properties of models, instead describing how models are to be
|
// of as properties of models, instead describing how models are to be
|
||||||
// processed in the context of the present tooling job.
|
// processed in the context of the present tooling job.
|
||||||
//
|
//
|
||||||
// Next ID to use: 30.
|
// Next ID to use: 31.
|
||||||
message TocoFlags {
|
message TocoFlags {
|
||||||
// Input file format
|
// Input file format
|
||||||
optional FileFormat input_format = 1;
|
optional FileFormat input_format = 1;
|
||||||
@ -212,4 +212,10 @@ message TocoFlags {
|
|||||||
// wish to implement kernels on reduced precision floats for performance
|
// wish to implement kernels on reduced precision floats for performance
|
||||||
// gains.
|
// gains.
|
||||||
optional bool quantize_to_float16 = 29 [default = false];
|
optional bool quantize_to_float16 = 29 [default = false];
|
||||||
|
|
||||||
|
// Boolean flag indicating whether the converter should allow models with
|
||||||
|
// dynamic Tensor shape. When set to False, the converter will generate
|
||||||
|
// runtime memory offsets for activation Tensors (with 128 bits alignment)
|
||||||
|
// and error out on models with undetermined Tensor shape. (Default: True)
|
||||||
|
optional bool allow_dynamic_tensors = 30 [default = true];
|
||||||
}
|
}
|
||||||
|
@ -452,6 +452,8 @@ tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model,
|
|||||||
params.enable_select_tf_ops =
|
params.enable_select_tf_ops =
|
||||||
toco_flags.force_select_tf_ops() || toco_flags.enable_select_tf_ops();
|
toco_flags.force_select_tf_ops() || toco_flags.enable_select_tf_ops();
|
||||||
params.allow_custom_ops = allow_custom_ops;
|
params.allow_custom_ops = allow_custom_ops;
|
||||||
|
params.allow_dynamic_tensors = toco_flags.allow_dynamic_tensors();
|
||||||
|
|
||||||
if (toco_flags.post_training_quantize()) {
|
if (toco_flags.post_training_quantize()) {
|
||||||
if (toco_flags.quantize_to_float16()) {
|
if (toco_flags.quantize_to_float16()) {
|
||||||
params.quantize_weights = tflite::QuantizedBufferType::FLOAT16;
|
params.quantize_weights = tflite::QuantizedBufferType::FLOAT16;
|
||||||
|
Loading…
Reference in New Issue
Block a user