diff --git a/tensorflow/lite/toco/BUILD b/tensorflow/lite/toco/BUILD index 80382864c71..43714fcf902 100644 --- a/tensorflow/lite/toco/BUILD +++ b/tensorflow/lite/toco/BUILD @@ -500,3 +500,16 @@ tf_cc_test( "@com_google_googletest//:gtest", ], ) + +tf_cc_test( + name = "toco_cmdline_flags_test", + srcs = [ + "toco_cmdline_flags_test.cc", + ], + deps = [ + ":toco_cmdline_flags", + ":toco_flags_proto_cc", + "//tensorflow/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/lite/toco/args.h b/tensorflow/lite/toco/args.h index 1003a157e42..6b6bb78be55 100644 --- a/tensorflow/lite/toco/args.h +++ b/tensorflow/lite/toco/args.h @@ -171,6 +171,7 @@ struct ParsedTocoFlags { Arg drop_fake_quant = Arg(false); Arg reorder_across_fake_quant = Arg(false); Arg allow_custom_ops = Arg(false); + Arg allow_dynamic_tensors = Arg(true); Arg post_training_quantize = Arg(false); Arg quantize_to_float16 = Arg(false); // Deprecated flags diff --git a/tensorflow/lite/toco/tflite/export.cc b/tensorflow/lite/toco/tflite/export.cc index 9b39569d9aa..c32466bc1f3 100644 --- a/tensorflow/lite/toco/tflite/export.cc +++ b/tensorflow/lite/toco/tflite/export.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -605,6 +605,13 @@ tensorflow::Status Export( /* name */ 0); std::vector> subgraphs = {subgraph}; + // TODO(wangtz): offline memory planning for activation Tensors. + if (!params.allow_dynamic_tensors) { + return tensorflow::errors::Unimplemented( + "Unsupported flag: allow_dynamic_tensors. Offline memory planning is " + "not implemented yet."); + } + auto buffers = ExportBuffers(model, buffers_to_write, &builder); auto description = builder.CreateString("TOCO Converted."); auto new_model_location = diff --git a/tensorflow/lite/toco/tflite/export.h b/tensorflow/lite/toco/tflite/export.h index 3a6031d22b8..3af77ffcf43 100644 --- a/tensorflow/lite/toco/tflite/export.h +++ b/tensorflow/lite/toco/tflite/export.h @@ -28,6 +28,7 @@ enum class QuantizedBufferType { NONE, INT8, FLOAT16 }; // The parameters for exporting a TFLite model. struct ExportParams { bool allow_custom_ops = false; + bool allow_dynamic_tensors = true; bool enable_select_tf_ops = false; QuantizedBufferType quantize_weights = QuantizedBufferType::NONE; }; diff --git a/tensorflow/lite/toco/tflite/export_test.cc b/tensorflow/lite/toco/tflite/export_test.cc index bbebf46a3b9..0ae6104f8f9 100644 --- a/tensorflow/lite/toco/tflite/export_test.cc +++ b/tensorflow/lite/toco/tflite/export_test.cc @@ -16,17 +16,19 @@ limitations under the License. #include #include +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/toco/tflite/builtin_operator.h" #include "tensorflow/lite/toco/tflite/operator.h" #include "tensorflow/lite/toco/tflite/types.h" -#include "tensorflow/core/framework/node_def.pb.h" namespace toco { namespace tflite { namespace { using ::testing::ElementsAre; +using ::testing::HasSubstr; class ExportTest : public ::testing::Test { protected: @@ -146,6 +148,11 @@ class ExportTest : public ::testing::Test { } } + tensorflow::Status ExportAndReturnStatus(const ExportParams& params) { + string result; + return Export(input_model_, &result, params); + } + std::vector ExportAndSummarizeOperators(const ExportParams& params) { std::vector names; @@ -213,6 +220,17 @@ TEST_F(ExportTest, LoadOperatorsMap) { 3, operators[details::OperatorKey(::tflite::BuiltinOperator_SUB, "", 1)]); } +TEST_F(ExportTest, UnsupportedFunctionality) { + AddOperatorsByName({"Conv"}); + + ExportParams params; + params.allow_dynamic_tensors = false; + auto status = ExportAndReturnStatus(params); + EXPECT_EQ(status.code(), ::tensorflow::error::UNIMPLEMENTED); + EXPECT_THAT(status.error_message(), + HasSubstr("Unsupported flag: allow_dynamic_tensors.")); +} + TEST_F(ExportTest, Export) { AddOperatorsByName({"Conv", "Add", "MyCrazyOp", "Sub"}); diff --git a/tensorflow/lite/toco/toco_cmdline_flags.cc b/tensorflow/lite/toco/toco_cmdline_flags.cc index c36b3de7748..d21f8d14112 100644 --- a/tensorflow/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/lite/toco/toco_cmdline_flags.cc @@ -124,6 +124,13 @@ bool ParseTocoFlagsFromCommandLineFlags( parsed_flags.allow_custom_ops.default_value(), "If true, allow TOCO to create TF Lite Custom operators for all the " "unsupported TensorFlow ops."), + Flag("allow_dynamic_tensors", parsed_flags.allow_dynamic_tensors.bind(), + parsed_flags.allow_dynamic_tensors.default_value(), + "Boolean flag indicating whether the converter should allow models " + "with dynamic Tensor shape. When set to False, the converter will " + "generate runtime memory offsets for activation Tensors (with 128 " + "bits alignment) and error out on models with undetermined Tensor " + "shape. (Default: True)"), Flag( "drop_control_dependency", parsed_flags.drop_control_dependency.bind(), diff --git a/tensorflow/lite/toco/toco_cmdline_flags_test.cc b/tensorflow/lite/toco/toco_cmdline_flags_test.cc new file mode 100644 index 00000000000..a1066e063bc --- /dev/null +++ b/tensorflow/lite/toco/toco_cmdline_flags_test.cc @@ -0,0 +1,60 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/toco/toco_cmdline_flags.h" + +#include + +#include +#include "tensorflow/lite/testing/util.h" + +namespace toco { +namespace { + +TEST(TocoCmdlineFlagsTest, DefaultValue) { + int argc = 1; + // Invariant in ANSI C, len(argv) == argc +1 also argv[argc] == nullptr + // TF flag parsing lib is relaying on this invariant. + const char* args[] = {"toco", nullptr}; + + string message; + ParsedTocoFlags result_flags; + + EXPECT_TRUE(ParseTocoFlagsFromCommandLineFlags( + &argc, const_cast(args), &message, &result_flags)); + EXPECT_EQ(result_flags.allow_dynamic_tensors.value(), true); +} + +TEST(TocoCmdlineFlagsTest, ParseFlags) { + int argc = 2; + const char* args[] = {"toco", "--allow_dynamic_tensors=false", nullptr}; + + string message; + ParsedTocoFlags result_flags; + + EXPECT_TRUE(ParseTocoFlagsFromCommandLineFlags( + &argc, const_cast(args), &message, &result_flags)); + EXPECT_EQ(result_flags.allow_dynamic_tensors.value(), false); +} + +} // namespace +} // namespace toco + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + ::toco::port::InitGoogleWasDoneElsewhere(); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/toco/toco_flags.proto b/tensorflow/lite/toco/toco_flags.proto index 50e9d332749..8e3550ded13 100644 --- a/tensorflow/lite/toco/toco_flags.proto +++ b/tensorflow/lite/toco/toco_flags.proto @@ -38,7 +38,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. // -// Next ID to use: 30. +// Next ID to use: 31. message TocoFlags { // Input file format optional FileFormat input_format = 1; @@ -212,4 +212,10 @@ message TocoFlags { // wish to implement kernels on reduced precision floats for performance // gains. optional bool quantize_to_float16 = 29 [default = false]; + + // Boolean flag indicating whether the converter should allow models with + // dynamic Tensor shape. When set to False, the converter will generate + // runtime memory offsets for activation Tensors (with 128 bits alignment) + // and error out on models with undetermined Tensor shape. (Default: True) + optional bool allow_dynamic_tensors = 30 [default = true]; } diff --git a/tensorflow/lite/toco/toco_tooling.cc b/tensorflow/lite/toco/toco_tooling.cc index 96f9d7602f1..020d228ad82 100644 --- a/tensorflow/lite/toco/toco_tooling.cc +++ b/tensorflow/lite/toco/toco_tooling.cc @@ -452,6 +452,8 @@ tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model, params.enable_select_tf_ops = toco_flags.force_select_tf_ops() || toco_flags.enable_select_tf_ops(); params.allow_custom_ops = allow_custom_ops; + params.allow_dynamic_tensors = toco_flags.allow_dynamic_tensors(); + if (toco_flags.post_training_quantize()) { if (toco_flags.quantize_to_float16()) { params.quantize_weights = tflite::QuantizedBufferType::FLOAT16;