TFL micro: Add allow_dynamic_tensor flag to TOCO.
PiperOrigin-RevId: 256914004
This commit is contained in:
parent
fa4de2ca7a
commit
dcb3eec82f
@ -500,3 +500,16 @@ tf_cc_test(
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "toco_cmdline_flags_test",
|
||||
srcs = [
|
||||
"toco_cmdline_flags_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":toco_cmdline_flags",
|
||||
":toco_flags_proto_cc",
|
||||
"//tensorflow/lite/testing:util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
@ -171,6 +171,7 @@ struct ParsedTocoFlags {
|
||||
Arg<bool> drop_fake_quant = Arg<bool>(false);
|
||||
Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
|
||||
Arg<bool> allow_custom_ops = Arg<bool>(false);
|
||||
Arg<bool> allow_dynamic_tensors = Arg<bool>(true);
|
||||
Arg<bool> post_training_quantize = Arg<bool>(false);
|
||||
Arg<bool> quantize_to_float16 = Arg<bool>(false);
|
||||
// Deprecated flags
|
||||
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -605,6 +605,13 @@ tensorflow::Status Export(
|
||||
/* name */ 0);
|
||||
std::vector<flatbuffers::Offset<SubGraph>> subgraphs = {subgraph};
|
||||
|
||||
// TODO(wangtz): offline memory planning for activation Tensors.
|
||||
if (!params.allow_dynamic_tensors) {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"Unsupported flag: allow_dynamic_tensors. Offline memory planning is "
|
||||
"not implemented yet.");
|
||||
}
|
||||
|
||||
auto buffers = ExportBuffers(model, buffers_to_write, &builder);
|
||||
auto description = builder.CreateString("TOCO Converted.");
|
||||
auto new_model_location =
|
||||
|
@ -28,6 +28,7 @@ enum class QuantizedBufferType { NONE, INT8, FLOAT16 };
|
||||
// The parameters for exporting a TFLite model.
|
||||
struct ExportParams {
|
||||
bool allow_custom_ops = false;
|
||||
bool allow_dynamic_tensors = true;
|
||||
bool enable_select_tf_ops = false;
|
||||
QuantizedBufferType quantize_weights = QuantizedBufferType::NONE;
|
||||
};
|
||||
|
@ -16,17 +16,19 @@ limitations under the License.
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/toco/tflite/builtin_operator.h"
|
||||
#include "tensorflow/lite/toco/tflite/operator.h"
|
||||
#include "tensorflow/lite/toco/tflite/types.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
|
||||
namespace toco {
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
class ExportTest : public ::testing::Test {
|
||||
protected:
|
||||
@ -146,6 +148,11 @@ class ExportTest : public ::testing::Test {
|
||||
}
|
||||
}
|
||||
|
||||
tensorflow::Status ExportAndReturnStatus(const ExportParams& params) {
|
||||
string result;
|
||||
return Export(input_model_, &result, params);
|
||||
}
|
||||
|
||||
std::vector<string> ExportAndSummarizeOperators(const ExportParams& params) {
|
||||
std::vector<string> names;
|
||||
|
||||
@ -213,6 +220,17 @@ TEST_F(ExportTest, LoadOperatorsMap) {
|
||||
3, operators[details::OperatorKey(::tflite::BuiltinOperator_SUB, "", 1)]);
|
||||
}
|
||||
|
||||
TEST_F(ExportTest, UnsupportedFunctionality) {
|
||||
AddOperatorsByName({"Conv"});
|
||||
|
||||
ExportParams params;
|
||||
params.allow_dynamic_tensors = false;
|
||||
auto status = ExportAndReturnStatus(params);
|
||||
EXPECT_EQ(status.code(), ::tensorflow::error::UNIMPLEMENTED);
|
||||
EXPECT_THAT(status.error_message(),
|
||||
HasSubstr("Unsupported flag: allow_dynamic_tensors."));
|
||||
}
|
||||
|
||||
TEST_F(ExportTest, Export) {
|
||||
AddOperatorsByName({"Conv", "Add", "MyCrazyOp", "Sub"});
|
||||
|
||||
|
@ -124,6 +124,13 @@ bool ParseTocoFlagsFromCommandLineFlags(
|
||||
parsed_flags.allow_custom_ops.default_value(),
|
||||
"If true, allow TOCO to create TF Lite Custom operators for all the "
|
||||
"unsupported TensorFlow ops."),
|
||||
Flag("allow_dynamic_tensors", parsed_flags.allow_dynamic_tensors.bind(),
|
||||
parsed_flags.allow_dynamic_tensors.default_value(),
|
||||
"Boolean flag indicating whether the converter should allow models "
|
||||
"with dynamic Tensor shape. When set to False, the converter will "
|
||||
"generate runtime memory offsets for activation Tensors (with 128 "
|
||||
"bits alignment) and error out on models with undetermined Tensor "
|
||||
"shape. (Default: True)"),
|
||||
Flag(
|
||||
"drop_control_dependency",
|
||||
parsed_flags.drop_control_dependency.bind(),
|
||||
|
60
tensorflow/lite/toco/toco_cmdline_flags_test.cc
Normal file
60
tensorflow/lite/toco/toco_cmdline_flags_test.cc
Normal file
@ -0,0 +1,60 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/toco/toco_cmdline_flags.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/testing/util.h"
|
||||
|
||||
namespace toco {
|
||||
namespace {
|
||||
|
||||
TEST(TocoCmdlineFlagsTest, DefaultValue) {
|
||||
int argc = 1;
|
||||
// Invariant in ANSI C, len(argv) == argc +1 also argv[argc] == nullptr
|
||||
// TF flag parsing lib is relaying on this invariant.
|
||||
const char* args[] = {"toco", nullptr};
|
||||
|
||||
string message;
|
||||
ParsedTocoFlags result_flags;
|
||||
|
||||
EXPECT_TRUE(ParseTocoFlagsFromCommandLineFlags(
|
||||
&argc, const_cast<char**>(args), &message, &result_flags));
|
||||
EXPECT_EQ(result_flags.allow_dynamic_tensors.value(), true);
|
||||
}
|
||||
|
||||
TEST(TocoCmdlineFlagsTest, ParseFlags) {
|
||||
int argc = 2;
|
||||
const char* args[] = {"toco", "--allow_dynamic_tensors=false", nullptr};
|
||||
|
||||
string message;
|
||||
ParsedTocoFlags result_flags;
|
||||
|
||||
EXPECT_TRUE(ParseTocoFlagsFromCommandLineFlags(
|
||||
&argc, const_cast<char**>(args), &message, &result_flags));
|
||||
EXPECT_EQ(result_flags.allow_dynamic_tensors.value(), false);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace toco
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
::toco::port::InitGoogleWasDoneElsewhere();
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
@ -38,7 +38,7 @@ enum FileFormat {
|
||||
// of as properties of models, instead describing how models are to be
|
||||
// processed in the context of the present tooling job.
|
||||
//
|
||||
// Next ID to use: 30.
|
||||
// Next ID to use: 31.
|
||||
message TocoFlags {
|
||||
// Input file format
|
||||
optional FileFormat input_format = 1;
|
||||
@ -212,4 +212,10 @@ message TocoFlags {
|
||||
// wish to implement kernels on reduced precision floats for performance
|
||||
// gains.
|
||||
optional bool quantize_to_float16 = 29 [default = false];
|
||||
|
||||
// Boolean flag indicating whether the converter should allow models with
|
||||
// dynamic Tensor shape. When set to False, the converter will generate
|
||||
// runtime memory offsets for activation Tensors (with 128 bits alignment)
|
||||
// and error out on models with undetermined Tensor shape. (Default: True)
|
||||
optional bool allow_dynamic_tensors = 30 [default = true];
|
||||
}
|
||||
|
@ -452,6 +452,8 @@ tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model,
|
||||
params.enable_select_tf_ops =
|
||||
toco_flags.force_select_tf_ops() || toco_flags.enable_select_tf_ops();
|
||||
params.allow_custom_ops = allow_custom_ops;
|
||||
params.allow_dynamic_tensors = toco_flags.allow_dynamic_tensors();
|
||||
|
||||
if (toco_flags.post_training_quantize()) {
|
||||
if (toco_flags.quantize_to_float16()) {
|
||||
params.quantize_weights = tflite::QuantizedBufferType::FLOAT16;
|
||||
|
Loading…
Reference in New Issue
Block a user