Fixes to experimental TFLite writer library:

1. Bring up to date with builtin-ops / builtin-options.
2. Add support for a few ops previously skipped.
3. Fix related to special optional tensor index.

PiperOrigin-RevId: 234188624
commit 8ade0b1b44
parent 398fce0307
@@ -112,5 +112,29 @@ inline LSHProjectionType LSHProjectionTypeToSchema(
   }
 }

+inline MirrorPadMode MirrorPaddingModeToSchema(TfLiteMirrorPaddingMode mode) {
+  switch (mode) {
+    case kTfLiteMirrorPaddingUnknown:
+      return MirrorPadMode_REFLECT;  // TODO(aselle): consider an error
+    case kTfLiteMirrorPaddingReflect:
+      return MirrorPadMode_REFLECT;
+    case kTfLiteMirrorPaddingSymmetric:
+      return MirrorPadMode_SYMMETRIC;
+  }
+}
+
+inline CombinerType CombinerTypeToSchema(TfLiteCombinerType type) {
+  switch (type) {
+    case kTfLiteCombinerTypeSum:
+      return CombinerType_SUM;
+    case kTfLiteCombinerTypeMean:
+      return CombinerType_MEAN;
+    case kTfLiteCombinerTypeSqrtn:
+      return CombinerType_SQRTN;
+  }
+}
+
 // int

 }  // namespace tflite
 #endif  // TENSORFLOW_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
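Note: the two helpers added above convert runtime enums from builtin_op_data.h
into their flatbuffer schema counterparts when a model is serialized back out.
A minimal usage sketch, assuming the generated schema headers are included and
that TfLiteMirrorPaddingParams carries its mode in a field named `mode`:

    // Convert a runtime MirrorPad attribute to the schema enum while building
    // the op's options table with a FlatBufferBuilder.
    flatbuffers::FlatBufferBuilder fbb;
    TfLiteMirrorPaddingParams params = {kTfLiteMirrorPaddingSymmetric};
    auto options = CreateMirrorPadOptions(
        fbb, MirrorPaddingModeToSchema(params.mode));  // MirrorPadMode_SYMMETRIC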
@@ -22,54 +22,59 @@ limitations under the License.
 namespace tflite {
 namespace {
 // This is generated by grepping
-// cat third_party/tensorflow/lite/builtin_op_data.h
-//| grep "^} TfLite" | sed 's/^} TfLite\(.*\)Params;/\1Params/g' | grep -v "^}"
-static const char* param_structs[] = {"TfLiteConvParams",
-                                      "TfLitePoolParams",
-                                      "TfLiteDepthwiseConvParams",
-                                      "TfLiteSVDFParams",
-                                      "TfLiteRNNParams",
-                                      "TfLiteSequenceRNNParams",
-                                      "TfLiteFullyConnectedParams",
-                                      "TfLiteLSHProjectionParams",
-                                      "TfLiteSoftmaxParams",
-                                      "TfLiteConcatenationParams",
-                                      "TfLiteAddParams",
-                                      "TfLiteSpaceToBatchNDParams",
+// cat third_party/tensorflow/lite/c/builtin_op_data.h | grep "^} TfLite" |
+// sed 's/^} \(TfLite.*\)Params;/\1Params/g' | grep -v "^}" | sed
+// 's/\(.*\)/"\1",/g' | sort
+static const char* param_structs[] = {"TfLiteAddParams",
+                                      "TfLiteArgMaxParams",
+                                      "TfLiteArgMinParams",
                                       "TfLiteBatchToSpaceNDParams",
-                                      "TfLiteMulParams",
-                                      "TfLiteSubParams",
+                                      "TfLiteBidirectionalSequenceLSTMParams",
+                                      "TfLiteBidirectionalSequenceRNNParams",
+                                      "TfLiteCastParams",
+                                      "TfLiteConcatenationParams",
+                                      "TfLiteConvParams",
+                                      "TfLiteDepthwiseConvParams",
                                       "TfLiteDivParams",
+                                      "TfLiteEmbeddingLookupSparseParams",
+                                      "TfLiteFakeQuantParams",
+                                      "TfLiteFullyConnectedParams",
+                                      "TfLiteGatherParams",
+                                      "TfLiteL2NormParams",
+                                      "TfLiteLeakyReluParams",
+                                      "TfLiteLocalResponseNormParams",
+                                      "TfLiteLSHProjectionParams",
+                                      "TfLiteLSTMParams",
-                                      "TfLiteResizeBilinearParams",
-                                      "TfLiteResizeNearestNeighborParams",
+                                      "TfLiteMirrorPaddingParams",
+                                      "TfLiteMulParams",
+                                      "TfLiteOneHotParams",
+                                      "TfLitePackParams",
                                       "TfLitePadParams",
                                       "TfLitePadV2Params",
-                                      "TfLiteReshapeParams",
-                                      "TfLiteSkipGramParams",
-                                      "TfLiteSpaceToDepthParams",
-                                      "TfLiteCastParams",
-                                      "TfLiteEmbeddingLookupSparseParams",
-                                      "TfLiteGatherParams",
-                                      "TfLiteTransposeParams",
+                                      "TfLitePoolParams",
+                                      "TfLiteReducerParams",
+                                      "TfLiteReshapeParams",
+                                      "TfLiteResizeBilinearParams",
+                                      "TfLiteResizeNearestNeighborParams",
+                                      "TfLiteRNNParams",
+                                      "TfLiteSequenceRNNParams",
+                                      "TfLiteShapeParams",
+                                      "TfLiteSkipGramParams",
+                                      "TfLiteSoftmaxParams",
+                                      "TfLiteSpaceToBatchNDParams",
+                                      "TfLiteSpaceToDepthParams",
+                                      "TfLiteSparseToDenseParams",
                                       "TfLiteSplitParams",
                                       "TfLiteSplitVParams",
                                       "TfLiteSqueezeParams",
                                       "TfLiteStridedSliceParams",
-                                      "TfLiteArgMaxParams",
-                                      "TfLiteArgMinParams",
+                                      "TfLiteSubParams",
+                                      "TfLiteSVDFParams",
                                       "TfLiteTransposeConvParams",
-                                      "TfLiteSparseToDenseParams",
-                                      "TfLiteShapeParams",
-                                      "TfLiteFakeQuantParams",
-                                      "TfLitePackParams",
-                                      "TfLiteOneHotParams",
-                                      "TfLiteLeakyReluParams",
-                                      "TfLiteMirrorPaddingParams",
+                                      "TfLiteTransposeParams",
+                                      "TfLiteUnidirectionalSequenceLSTMParams",
+                                      "TfLiteUniqueParams",
+                                      "TfLiteUnpackParams",
                                       nullptr};
 }  // namespace

@@ -142,7 +147,6 @@ class OpOptionData {
     op_to_option_["REDUCE_MAX"] = "ReducerOptions";
     op_to_option_["REDUCE_MIN"] = "ReducerOptions";
     op_to_option_["REDUCE_ANY"] = "ReducerOptions";
-    op_to_option_["UNPACK"] = "";
     op_to_option_["SUM"] = "ReducerOptions";
     op_to_option_["REDUCE_MAX"] = "ReducerOptions";
     op_to_option_["REDUCE_PROD"] = "ReducerOptions";
@@ -151,35 +155,30 @@ class OpOptionData {
     op_to_option_["AVERAGE_POOL_2D"] = "Pool2DOptions";
     op_to_option_["MAX_POOL_2D"] = "Pool2DOptions";
     op_to_option_["L2_NORMALIZATION"] = "L2NormOptions";
     op_to_option_["BIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions";
     op_to_option_["UNIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions";
     op_to_option_["BIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
     op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
-    op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
-    op_to_option_["MIRROR_PAD"] = "";  // TODO(karimnosseir): MirrorPadOptions.
-    op_to_option_["UNIQUE"] = "";  // TODO(karimnosseir): UniqueOptions.
-    // Manually specified mappings between ops and options (none)
-    op_to_option_["EMBEDDING_LOOKUP"] =
-        "";  // TODO(aselle): maybe something else.
-    op_to_option_["MAXIMUM"] = "MaximumMinimumOptions";
-    op_to_option_["MINIMUM"] = "MaximumMinimumOptions";
-    op_to_option_["CUSTOM"] = "";  // TODO(aselle): maybe something else.
-    op_to_option_["DELEGATE"] = "";  // TODO(aselle): maybe something else.

+    // Manually specified mappings between ops to "none" options -- these are
+    // ops without a corresponding Options message in schema as yet. If these
+    // options do get assigned an Options message in future, they need to be
+    // updated here as well.
+    op_to_option_["EMBEDDING_LOOKUP"] = "";
     op_to_option_["FLOOR"] = "";
     op_to_option_["CEIL"] = "";
-    op_to_option_["HASHTABLE_LOOKUP"] =
-        "";  // TODO(aselle): maybe something else.
+    op_to_option_["HASHTABLE_LOOKUP"] = "";
     op_to_option_["LOGISTIC"] = "";
     op_to_option_["RELU"] = "";
     op_to_option_["RELU_N1_TO_1"] = "";
     op_to_option_["RELU6"] = "";
     op_to_option_["TANH"] = "";
+    op_to_option_["CUSTOM"] = "";  // TODO(aselle): maybe something else.
+    op_to_option_["DELEGATE"] = "";  // TODO(aselle): maybe something else.
     op_to_option_["PRELU"] = "";
+    op_to_option_["MAXIMUM"] = "";  // TODO(aselle): MaximumMinimumOptions
+    op_to_option_["MINIMUM"] = "";  // TODO(aselle): MaximumMinimumOptions
     op_to_option_["SIN"] = "";
     op_to_option_["LOG"] = "";
     op_to_option_["SQRT"] = "";
     op_to_option_["RSQRT"] = "";
     op_to_option_["Rank"] = "";

     // TODO(aselle): These are undesirable hacks. Consider changing C structs
     option_to_struct_["Pool2DOptions"] = "TfLitePoolParams";
@@ -187,6 +186,7 @@ class OpOptionData {
     option_to_struct_["DepthwiseConv2DOptions"] = "TfLiteDepthwiseConvParams";
     option_to_struct_["LocalResponseNormalizationOptions"] =
         "TfLiteLocalResponseNormParams";
+    option_to_struct_["MirrorPadOptions"] = "TfLiteMirrorPaddingParams";
     // Now for every op, try to find an option.
     bool fatal = false;
     for (auto op_name : ops_) {
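Note: the generator resolves options in two steps: op_to_option_ maps an op
name to an option name (an empty string means the op carries no options, and
ops absent from the map fall back to a name derived from the op itself, e.g.
MIRROR_PAD -> MirrorPadOptions), then option_to_struct_ maps the option name
to its TfLite*Params C struct. A sketch of that lookup under those
assumptions, with a hypothetical helper name:

    #include <map>
    #include <string>

    // Hypothetical sketch of the per-op resolution described above.
    std::string OptionForOp(
        const std::map<std::string, std::string>& op_to_option,
        const std::string& op_name, const std::string& derived_name) {
      auto it = op_to_option.find(op_name);
      if (it != op_to_option.end()) return it->second;  // "" disables options
      return derived_name;  // e.g. "MIRROR_PAD" -> "MirrorPadOptions"
    }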
@@ -226,13 +226,15 @@ class OpOptionData {
         if (!param_struct_found) {
           std::cerr << "Failed to get param struct for option " << option_name
                     << std::endl;
           fatal = true;
         } else {
           option_to_struct_.insert(std::make_pair(option_name, params_guess));
         }
       }
     }
   }
   if (fatal) {
     exit(1);
   }
 }

  private:
@@ -243,16 +245,28 @@ class OpOptionData {
       option_to_type_function_;
 };

+void GenerateImportForResizeBilinearOp(FILE* fp) {
+  fprintf(fp,
+          "    case BuiltinOperator_RESIZE_BILINEAR: {\n"
+          "      const auto* params = reinterpret_cast<const "
+          "TfLiteResizeBilinearParams*>(builtin_op_data);\n"
+          "      auto union_type = CreateResizeBilinearOptions(*fbb, "
+          "params->align_corners).Union();\n"
+          "      return std::make_pair(BuiltinOptions_ResizeBilinearOptions, "
+          "union_type);\n"
+          "    }\n    break;\n");
+}
+
 void GenerateImportForOp(FILE* fp, const std::string& op_name,
                          const std::string& option_name,
                          const std::string& option_type,
                          const flatbuffers::TypeTable* options,
                          const std::string& struct_name) {
   // Skip tricky ones for now
-  if (struct_name == "TfLiteResizeBilinearParams") return;
   if (struct_name == "TfLiteSqueezeParams") return;
   if (struct_name == "TfLiteEmbeddingLookupSparseParams") return;
   if (struct_name == "TfLiteReshapeParams") return;
+  // Special-case ResizeBilinear which has some deprecated fields.
+  if (struct_name == "TfLiteResizeBilinearParams") {
+    GenerateImportForResizeBilinearOp(fp);
+    return;
+  }

   fprintf(fp, "    case BuiltinOperator_%s: {\n", op_name.c_str());
   fprintf(fp,
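Note: once the fprintf escaping is undone, the case emitted by
GenerateImportForResizeBilinearOp reads roughly as follows (fbb and
builtin_op_data are variables of the surrounding generated function; the
special-casing exists so the deprecated fields of ResizeBilinearOptions are
left untouched):

    case BuiltinOperator_RESIZE_BILINEAR: {
      const auto* params =
          reinterpret_cast<const TfLiteResizeBilinearParams*>(builtin_op_data);
      auto union_type =
          CreateResizeBilinearOptions(*fbb, params->align_corners).Union();
      return std::make_pair(BuiltinOptions_ResizeBilinearOptions, union_type);
    } break;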
@@ -262,6 +276,9 @@ void GenerateImportForOp(FILE* fp, const std::string& op_name,

   for (size_t i = 0; i < options->num_elems; i++) {
     std::string elem_name = options->names[i];
+    bool is_int_vector = false;
+    std::string vector_name = elem_name;
+    std::string vector_size;
     // TODO(aselle): Irregular naming in builtins
     if (elem_name == "fused_activation_function")
       elem_name = "activation";
@@ -273,8 +290,26 @@ void GenerateImportForOp(FILE* fp, const std::string& op_name,
       elem_name = "dilation_height_factor";
     else if (elem_name == "dilation_w_factor")
       elem_name = "dilation_width_factor";
-    else if (elem_name == "new_shape")
-      elem_name = "shape";
+    else if (elem_name == "idx_out_type")
+      elem_name = "index_out_type";
+
+    // Vector fields treated specially.
+    if (elem_name == "new_shape") {
+      is_int_vector = true;
+      vector_name = "shape";
+      vector_size = "num_dimensions";
+    } else if (elem_name == "squeeze_dims") {
+      is_int_vector = true;
+      vector_size = "num_squeeze_dims";
+    }
+
+    if (is_int_vector) {
+      fprintf(fp,
+              "      auto val%zu = fbb->CreateVector("
+              "std::vector<int>(params->%s, params->%s + params->%s));\n",
+              i, vector_name.c_str(), vector_name.c_str(), vector_size.c_str());
+      continue;
+    }

     flatbuffers::TypeCode code = options->type_codes[i];
     auto contained_type = code.sequence_ref != -1
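Note: for the two int-vector fields, the emitted line copies the C array out
of the params struct into a flatbuffer vector. For TfLiteSqueezeParams (field
squeeze_dims, assuming it is element i == 0 of its options table), the fprintf
above expands to roughly:

    auto val0 = fbb->CreateVector(std::vector<int>(
        params->squeeze_dims,
        params->squeeze_dims + params->num_squeeze_dims));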
@@ -293,6 +328,10 @@ void GenerateImportForOp(FILE* fp, const std::string& op_name,
       mapper = "LSTMKernelTypeToSchema";
     } else if (contained_type == LSHProjectionTypeTypeTable) {
       mapper = "LSHProjectionTypeToSchema";
+    } else if (contained_type == MirrorPadModeTypeTable) {
+      mapper = "MirrorPaddingModeToSchema";
+    } else if (contained_type == CombinerTypeTypeTable) {
+      mapper = "CombinerTypeToSchema";
     }

     fprintf(fp,
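Note: when a field's type table matches one of these enums, the generated
conversion routes the raw params value through the named helper from
enum_mapping.h, so the emitted expression for a MirrorPad mode field is along
the lines of MirrorPaddingModeToSchema(params->mode) rather than a plain cast.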
@@ -219,6 +219,11 @@ std::vector<int> InterpreterWriter::RemapTensorIndicesToWritten(
   std::vector<int> output;
   output.reserve(input.size());
   for (int x : input) {
+    // Special value representing an optional tensor which is not present.
+    if (x == -1) {
+      output.push_back(x);
+      continue;
+    }
     if (tensor_to_written_tensor_[x] != -1) {
       output.push_back(tensor_to_written_tensor_[x]);
     }
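Note: TFLite marks an absent optional tensor (e.g. the bias of a fully
connected op) with index -1, so the remap must pass -1 through instead of
using it to index tensor_to_written_tensor_. A self-contained sketch of the
rule this hunk adds:

    #include <vector>

    std::vector<int> Remap(const std::vector<int>& input,
                           const std::vector<int>& tensor_to_written_tensor) {
      std::vector<int> output;
      output.reserve(input.size());
      for (int x : input) {
        if (x == -1) {  // optional tensor not present: keep the marker
          output.push_back(-1);
          continue;
        }
        if (tensor_to_written_tensor[x] != -1) {  // tensor was written
          output.push_back(tensor_to_written_tensor[x]);
        }
      }
      return output;
    }

    // e.g. Remap({2, 5, -1}, {0, -1, 1, -1, -1, 2}) yields {1, 2, -1}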