346 lines
18 KiB
C++
346 lines
18 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "absl/strings/numbers.h"
|
|
#include "absl/strings/str_join.h"
|
|
#include "absl/strings/str_split.h"
|
|
#include "absl/strings/strip.h"
|
|
#include "absl/types/optional.h"
|
|
#include "tensorflow/lite/toco/toco_cmdline_flags.h"
|
|
#include "tensorflow/lite/toco/toco_port.h"
|
|
#include "tensorflow/core/platform/logging.h"
|
|
#include "tensorflow/core/util/command_line_flags.h"
|
|
|
|
namespace toco {
|
|
|
|
bool ParseTocoFlagsFromCommandLineFlags(
|
|
int* argc, char* argv[], string* msg,
|
|
ParsedTocoFlags* parsed_toco_flags_ptr) {
|
|
using tensorflow::Flag;
|
|
ParsedTocoFlags& parsed_flags = *parsed_toco_flags_ptr;
|
|
std::vector<tensorflow::Flag> flags = {
|
|
Flag("input_file", parsed_flags.input_file.bind(),
|
|
parsed_flags.input_file.default_value(),
|
|
"Input file (model of any supported format). For Protobuf "
|
|
"formats, both text and binary are supported regardless of file "
|
|
"extension."),
|
|
Flag("savedmodel_directory", parsed_flags.savedmodel_directory.bind(),
|
|
parsed_flags.savedmodel_directory.default_value(),
|
|
"Deprecated. Full path to the directory containing the SavedModel."),
|
|
Flag("output_file", parsed_flags.output_file.bind(),
|
|
parsed_flags.output_file.default_value(),
|
|
"Output file. "
|
|
"For Protobuf formats, the binary format will be used."),
|
|
Flag("input_format", parsed_flags.input_format.bind(),
|
|
parsed_flags.input_format.default_value(),
|
|
"Input file format. One of: TENSORFLOW_GRAPHDEF, TFLITE."),
|
|
Flag("output_format", parsed_flags.output_format.bind(),
|
|
parsed_flags.output_format.default_value(),
|
|
"Output file format. "
|
|
"One of TENSORFLOW_GRAPHDEF, TFLITE, GRAPHVIZ_DOT."),
|
|
Flag("savedmodel_tagset", parsed_flags.savedmodel_tagset.bind(),
|
|
parsed_flags.savedmodel_tagset.default_value(),
|
|
"Deprecated. Comma-separated set of tags identifying the "
|
|
"MetaGraphDef within the SavedModel to analyze. All tags in the tag "
|
|
"set must be specified."),
|
|
Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(),
|
|
parsed_flags.default_ranges_min.default_value(),
|
|
"If defined, will be used as the default value for the min bound "
|
|
"of min/max ranges used for quantization of uint8 arrays."),
|
|
Flag("default_ranges_max", parsed_flags.default_ranges_max.bind(),
|
|
parsed_flags.default_ranges_max.default_value(),
|
|
"If defined, will be used as the default value for the max bound "
|
|
"of min/max ranges used for quantization of uint8 arrays."),
|
|
Flag("default_int16_ranges_min",
|
|
parsed_flags.default_int16_ranges_min.bind(),
|
|
parsed_flags.default_int16_ranges_min.default_value(),
|
|
"If defined, will be used as the default value for the min bound "
|
|
"of min/max ranges used for quantization of int16 arrays."),
|
|
Flag("default_int16_ranges_max",
|
|
parsed_flags.default_int16_ranges_max.bind(),
|
|
parsed_flags.default_int16_ranges_max.default_value(),
|
|
"If defined, will be used as the default value for the max bound "
|
|
"of min/max ranges used for quantization of int16 arrays."),
|
|
Flag("inference_type", parsed_flags.inference_type.bind(),
|
|
parsed_flags.inference_type.default_value(),
|
|
"Target data type of arrays in the output file (for input_arrays, "
|
|
"this may be overridden by inference_input_type). "
|
|
"One of FLOAT, QUANTIZED_UINT8."),
|
|
Flag("inference_input_type", parsed_flags.inference_input_type.bind(),
|
|
parsed_flags.inference_input_type.default_value(),
|
|
"Target data type of input arrays. "
|
|
"If not specified, inference_type is used. "
|
|
"One of FLOAT, QUANTIZED_UINT8."),
|
|
Flag("input_type", parsed_flags.input_type.bind(),
|
|
parsed_flags.input_type.default_value(),
|
|
"Deprecated ambiguous flag that set both --input_data_types and "
|
|
"--inference_input_type."),
|
|
Flag("input_types", parsed_flags.input_types.bind(),
|
|
parsed_flags.input_types.default_value(),
|
|
"Deprecated ambiguous flag that set both --input_data_types and "
|
|
"--inference_input_type. Was meant to be a "
|
|
"comma-separated list, but this was deprecated before "
|
|
"multiple-input-types was ever properly supported."),
|
|
|
|
Flag("drop_fake_quant", parsed_flags.drop_fake_quant.bind(),
|
|
parsed_flags.drop_fake_quant.default_value(),
|
|
"Ignore and discard FakeQuant nodes. For instance, to "
|
|
"generate plain float code without fake-quantization from a "
|
|
"quantized graph."),
|
|
Flag(
|
|
"reorder_across_fake_quant",
|
|
parsed_flags.reorder_across_fake_quant.bind(),
|
|
parsed_flags.reorder_across_fake_quant.default_value(),
|
|
"Normally, FakeQuant nodes must be strict boundaries for graph "
|
|
"transformations, in order to ensure that quantized inference has "
|
|
"the exact same arithmetic behavior as quantized training --- which "
|
|
"is the whole point of quantized training and of FakeQuant nodes in "
|
|
"the first place. "
|
|
"However, that entails subtle requirements on where exactly "
|
|
"FakeQuant nodes must be placed in the graph. Some quantized graphs "
|
|
"have FakeQuant nodes at unexpected locations, that prevent graph "
|
|
"transformations that are necessary in order to generate inference "
|
|
"code for these graphs. Such graphs should be fixed, but as a "
|
|
"temporary work-around, setting this reorder_across_fake_quant flag "
|
|
"allows TOCO to perform necessary graph transformaitons on them, "
|
|
"at the cost of no longer faithfully matching inference and training "
|
|
"arithmetic."),
|
|
Flag("allow_custom_ops", parsed_flags.allow_custom_ops.bind(),
|
|
parsed_flags.allow_custom_ops.default_value(),
|
|
"If true, allow TOCO to create TF Lite Custom operators for all the "
|
|
"unsupported TensorFlow ops."),
|
|
Flag("allow_dynamic_tensors", parsed_flags.allow_dynamic_tensors.bind(),
|
|
parsed_flags.allow_dynamic_tensors.default_value(),
|
|
"Boolean flag indicating whether the converter should allow models "
|
|
"with dynamic Tensor shape. When set to False, the converter will "
|
|
"generate runtime memory offsets for activation Tensors (with 128 "
|
|
"bits alignment) and error out on models with undetermined Tensor "
|
|
"shape. (Default: True)"),
|
|
Flag(
|
|
"drop_control_dependency",
|
|
parsed_flags.drop_control_dependency.bind(),
|
|
parsed_flags.drop_control_dependency.default_value(),
|
|
"If true, ignore control dependency requirements in input TensorFlow "
|
|
"GraphDef. Otherwise an error will be raised upon control dependency "
|
|
"inputs."),
|
|
Flag("debug_disable_recurrent_cell_fusion",
|
|
parsed_flags.debug_disable_recurrent_cell_fusion.bind(),
|
|
parsed_flags.debug_disable_recurrent_cell_fusion.default_value(),
|
|
"If true, disable fusion of known identifiable cell subgraphs into "
|
|
"cells. This includes, for example, specific forms of LSTM cell."),
|
|
Flag("propagate_fake_quant_num_bits",
|
|
parsed_flags.propagate_fake_quant_num_bits.bind(),
|
|
parsed_flags.propagate_fake_quant_num_bits.default_value(),
|
|
"If true, use FakeQuant* operator num_bits attributes to adjust "
|
|
"array data_types."),
|
|
Flag("allow_nudging_weights_to_use_fast_gemm_kernel",
|
|
parsed_flags.allow_nudging_weights_to_use_fast_gemm_kernel.bind(),
|
|
parsed_flags.allow_nudging_weights_to_use_fast_gemm_kernel
|
|
.default_value(),
|
|
"Some fast uint8 GEMM kernels require uint8 weights to avoid the "
|
|
"value 0. This flag allows nudging them to 1 to allow proceeding, "
|
|
"with moderate inaccuracy."),
|
|
Flag("dedupe_array_min_size_bytes",
|
|
parsed_flags.dedupe_array_min_size_bytes.bind(),
|
|
parsed_flags.dedupe_array_min_size_bytes.default_value(),
|
|
"Minimum size of constant arrays to deduplicate; arrays smaller "
|
|
"will not be deduplicated."),
|
|
Flag("split_tflite_lstm_inputs",
|
|
parsed_flags.split_tflite_lstm_inputs.bind(),
|
|
parsed_flags.split_tflite_lstm_inputs.default_value(),
|
|
"Split the LSTM inputs from 5 tensors to 18 tensors for TFLite. "
|
|
"Ignored if the output format is not TFLite."),
|
|
Flag("quantize_to_float16", parsed_flags.quantize_to_float16.bind(),
|
|
parsed_flags.quantize_to_float16.default_value(),
|
|
"Used in conjuction with post_training_quantize. Specifies that "
|
|
"the weights should be quantized to fp16 instead of the default "
|
|
"(int8)"),
|
|
Flag("quantize_weights", parsed_flags.quantize_weights.bind(),
|
|
parsed_flags.quantize_weights.default_value(),
|
|
"Deprecated. Please use --post_training_quantize instead."),
|
|
Flag("post_training_quantize", parsed_flags.post_training_quantize.bind(),
|
|
parsed_flags.post_training_quantize.default_value(),
|
|
"Boolean indicating whether to quantize the weights of the "
|
|
"converted float model. Model size will be reduced and there will "
|
|
"be latency improvements (at the cost of accuracy)."),
|
|
// TODO(b/118822804): Unify the argument definition with `tflite_convert`.
|
|
// WARNING: Experimental interface, subject to change
|
|
Flag("enable_select_tf_ops", parsed_flags.enable_select_tf_ops.bind(),
|
|
parsed_flags.enable_select_tf_ops.default_value(), ""),
|
|
// WARNING: Experimental interface, subject to change
|
|
Flag("force_select_tf_ops", parsed_flags.force_select_tf_ops.bind(),
|
|
parsed_flags.force_select_tf_ops.default_value(), "")};
|
|
bool asked_for_help =
|
|
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
|
|
if (asked_for_help) {
|
|
*msg += tensorflow::Flags::Usage(argv[0], flags);
|
|
return false;
|
|
} else {
|
|
return tensorflow::Flags::Parse(argc, argv, flags);
|
|
}
|
|
}
|
|
|
|
namespace {
|
|
|
|
// Defines the requirements for a given flag. kUseDefault means the default
// should be used in cases where the value isn't specified by the user.
enum class FlagRequirement {
  // No constraint: the flag is copied into TocoFlags only when the user
  // explicitly specified it (see READ_TOCO_FLAG / GetFlagValue below).
  kNone,
  // The flag must be specified; EnforceFlagRequirement QCHECK-fails otherwise.
  kMustBeSpecified,
  // The flag must NOT be specified; EnforceFlagRequirement QCHECK-fails
  // if it was (used when other flags make this one invalid).
  kMustNotBeSpecified,
  // Fall back to the flag's default value when the user didn't specify one,
  // so the corresponding TocoFlags field is always set.
  kUseDefault,
};
|
|
|
|
// Enforces the FlagRequirements are met for a given flag.
|
|
template <typename T>
|
|
void EnforceFlagRequirement(const T& flag, const string& flag_name,
|
|
FlagRequirement requirement) {
|
|
if (requirement == FlagRequirement::kMustBeSpecified) {
|
|
QCHECK(flag.specified()) << "Missing required flag " << flag_name;
|
|
}
|
|
if (requirement == FlagRequirement::kMustNotBeSpecified) {
|
|
QCHECK(!flag.specified())
|
|
<< "Given other flags, this flag should not have been specified: "
|
|
<< flag_name;
|
|
}
|
|
}
|
|
|
|
// Gets the value from the flag if specified. Returns default if the
|
|
// FlagRequirement is kUseDefault.
|
|
template <typename T>
|
|
absl::optional<T> GetFlagValue(const Arg<T>& flag,
|
|
FlagRequirement requirement) {
|
|
if (flag.specified()) return flag.value();
|
|
if (requirement == FlagRequirement::kUseDefault) return flag.default_value();
|
|
return absl::optional<T>();
|
|
}
|
|
|
|
} // namespace
|
|
|
|
void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
|
|
TocoFlags* toco_flags) {
|
|
namespace port = toco::port;
|
|
port::CheckInitGoogleIsDone("InitGoogle is not done yet");
|
|
|
|
#define READ_TOCO_FLAG(name, requirement) \
|
|
do { \
|
|
EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement); \
|
|
auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \
|
|
if (flag_value.has_value()) { \
|
|
toco_flags->set_##name(flag_value.value()); \
|
|
} \
|
|
} while (false)
|
|
|
|
#define PARSE_TOCO_FLAG(Type, name, requirement) \
|
|
do { \
|
|
EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement); \
|
|
auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \
|
|
if (flag_value.has_value()) { \
|
|
Type x; \
|
|
QCHECK(Type##_Parse(flag_value.value(), &x)) \
|
|
<< "Unrecognized " << #Type << " value " \
|
|
<< parsed_toco_flags.name.value(); \
|
|
toco_flags->set_##name(x); \
|
|
} \
|
|
} while (false)
|
|
|
|
PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kUseDefault);
|
|
PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kUseDefault);
|
|
PARSE_TOCO_FLAG(IODataType, inference_type, FlagRequirement::kNone);
|
|
PARSE_TOCO_FLAG(IODataType, inference_input_type, FlagRequirement::kNone);
|
|
READ_TOCO_FLAG(default_ranges_min, FlagRequirement::kNone);
|
|
READ_TOCO_FLAG(default_ranges_max, FlagRequirement::kNone);
|
|
READ_TOCO_FLAG(default_int16_ranges_min, FlagRequirement::kNone);
|
|
READ_TOCO_FLAG(default_int16_ranges_max, FlagRequirement::kNone);
|
|
READ_TOCO_FLAG(drop_fake_quant, FlagRequirement::kNone);
|
|
READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone);
|
|
READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone);
|
|
READ_TOCO_FLAG(drop_control_dependency, FlagRequirement::kNone);
|
|
READ_TOCO_FLAG(debug_disable_recurrent_cell_fusion, FlagRequirement::kNone);
|
|
READ_TOCO_FLAG(propagate_fake_quant_num_bits, FlagRequirement::kNone);
|
|
READ_TOCO_FLAG(allow_nudging_weights_to_use_fast_gemm_kernel,
|
|
FlagRequirement::kNone);
|
|
READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone);
|
|
READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone);
|
|
READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone);
|
|
READ_TOCO_FLAG(quantize_to_float16, FlagRequirement::kNone);
|
|
READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone);
|
|
READ_TOCO_FLAG(enable_select_tf_ops, FlagRequirement::kNone);
|
|
READ_TOCO_FLAG(force_select_tf_ops, FlagRequirement::kNone);
|
|
|
|
if (parsed_toco_flags.force_select_tf_ops.value() &&
|
|
!parsed_toco_flags.enable_select_tf_ops.value()) {
|
|
// TODO(ycling): Consider to enforce `enable_select_tf_ops` when
|
|
// `force_select_tf_ops` is true.
|
|
LOG(WARNING) << "--force_select_tf_ops should always be used with "
|
|
"--enable_select_tf_ops.";
|
|
}
|
|
|
|
// Deprecated flag handling.
|
|
if (parsed_toco_flags.input_type.specified()) {
|
|
LOG(WARNING)
|
|
<< "--input_type is deprecated. It was an ambiguous flag that set both "
|
|
"--input_data_types and --inference_input_type. If you are trying "
|
|
"to complement the input file with information about the type of "
|
|
"input arrays, use --input_data_type. If you are trying to control "
|
|
"the quantization/dequantization of real-numbers input arrays in "
|
|
"the output file, use --inference_input_type.";
|
|
toco::IODataType input_type;
|
|
QCHECK(toco::IODataType_Parse(parsed_toco_flags.input_type.value(),
|
|
&input_type));
|
|
toco_flags->set_inference_input_type(input_type);
|
|
}
|
|
if (parsed_toco_flags.input_types.specified()) {
|
|
LOG(WARNING)
|
|
<< "--input_types is deprecated. It was an ambiguous flag that set "
|
|
"both --input_data_types and --inference_input_type. If you are "
|
|
"trying to complement the input file with information about the "
|
|
"type of input arrays, use --input_data_type. If you are trying to "
|
|
"control the quantization/dequantization of real-numbers input "
|
|
"arrays in the output file, use --inference_input_type.";
|
|
std::vector<string> input_types =
|
|
absl::StrSplit(parsed_toco_flags.input_types.value(), ',');
|
|
QCHECK(!input_types.empty());
|
|
for (int i = 1; i < input_types.size(); i++) {
|
|
QCHECK_EQ(input_types[i], input_types[0]);
|
|
}
|
|
toco::IODataType input_type;
|
|
QCHECK(toco::IODataType_Parse(input_types[0], &input_type));
|
|
toco_flags->set_inference_input_type(input_type);
|
|
}
|
|
if (parsed_toco_flags.quantize_weights.value()) {
|
|
LOG(WARNING)
|
|
<< "--quantize_weights is deprecated. Falling back to "
|
|
"--post_training_quantize. Please switch --post_training_quantize.";
|
|
toco_flags->set_post_training_quantize(
|
|
parsed_toco_flags.quantize_weights.value());
|
|
}
|
|
if (parsed_toco_flags.quantize_weights.value()) {
|
|
if (toco_flags->inference_type() == IODataType::QUANTIZED_UINT8) {
|
|
LOG(WARNING)
|
|
<< "--post_training_quantize quantizes a graph of inference_type "
|
|
"FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.";
|
|
toco_flags->set_inference_type(IODataType::FLOAT);
|
|
}
|
|
}
|
|
|
|
#undef READ_TOCO_FLAG
|
|
#undef PARSE_TOCO_FLAG
|
|
}
|
|
} // namespace toco
|