Properly unittest all supported conversion modes.
PiperOrigin-RevId: 217799124
commit 6e45e6ae21
parent 79bff06f77
@@ -30,6 +30,7 @@ using ::testing::ElementsAre;
 
 class ExportTest : public ::testing::Test {
  protected:
+  void ResetOperators() { input_model_.operators.clear(); }
   void AddTensorsByName(std::initializer_list<string> names) {
     for (const string& name : names) {
       input_model_.GetOrCreateArray(name);
@@ -93,7 +94,11 @@ class ExportTest : public ::testing::Test {
     std::vector<string> names;
 
     string result;
-    if (!Export(input_model_, &result, params).ok()) return names;
+    auto status = Export(input_model_, &result, params);
+    if (!status.ok()) {
+      LOG(INFO) << status.error_message();
+      return names;
+    }
 
     auto* model = ::tflite::GetModel(result.data());
 
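The hunk above replaces a silent early return with explicit status handling so that export failures are visible in the test log. A minimal standalone sketch of the same pattern, assuming the surrounding test's types (Model, ExportParams, Export, LOG, and error_message are taken from this diff; the function name TryExport and its signature are illustrative):

// Hedged sketch, not part of the test file: run the export and, on failure,
// log the reason instead of failing silently.
bool TryExport(const Model& model, const ExportParams& params,
               string* serialized) {
  auto status = Export(model, serialized, params);
  if (!status.ok()) {
    LOG(INFO) << status.error_message();  // surface the failure in test logs
    return false;
  }
  return true;
}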
@@ -166,39 +171,6 @@ TEST_F(ExportTest, Export) {
   EXPECT_THAT(ExportAndGetOperatorIndices(params), ElementsAre(1, 0, 2, 3));
 }
 
-TEST_F(ExportTest, ExportWithOptions) {
-  // We have three main types of operators:
-  // 1) first-class TOCO/TF Lite operators, which are exported as TF Lite
-  //    builtins
-  // 2) operators that are not recognized by TOCO but are known to the
-  //    standard TensorFlow runtime. Those should be exported as FLEX
-  //    operators, when requested. Otherwise they are custom TF Lite ops.
-  // 3) operators that are not recognized and should be exported as custom
-  //    TF Lite ops. There are two subtypes: ops that would be recognized
-  //    by a custom TensorFlow runtime, and "fake" ops that have only a
-  //    TF Lite impl.
-  AddOperatorsByName({"Add", "ResizeNearestNeighbor"});
-
-  ExportParams params;
-  params.allow_custom_ops = false;
-  params.allow_flex_ops = false;
-  params.quantize_weights = false;
-
-  // Conversion fails because ResizeNearestNeighbor is unknown.
-  EXPECT_THAT(ExportAndSummarizeOperators(params), ElementsAre());
-
-  // ResizeNearestNeighbor is treated as a simple custom op (#3 above).
-  params.allow_custom_ops = true;
-  EXPECT_THAT(ExportAndSummarizeOperators(params),
-              ElementsAre("builtin:ADD", "custom:ResizeNearestNeighbor"));
-
-  // ResizeNearestNeighbor is recognized as a TensorFlow (Flex) op.
-  params.allow_custom_ops = true;
-  params.allow_flex_ops = true;
-  EXPECT_THAT(ExportAndSummarizeOperators(params),
-              ElementsAre("builtin:ADD", "custom:FlexResizeNearestNeighbor"));
-}
-
 TEST_F(ExportTest, QuantizeWeights) {
   // Sanity check for quantize_weights parameter.
   BuildQuantizableTestModel();
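The comment block removed above describes the three categories of operators the exporter can emit. As a hedged orientation sketch (not part of the test file), here is how those categories map onto ExportParams and onto the operator summaries asserted in this file; the field names and summary strings are taken from this diff, while the arrangement is illustrative:

// Hedged sketch, not part of the test file.
ExportParams params;

// 1) First-class TOCO/TF Lite operators are exported as builtins,
//    e.g. "Add" -> "builtin:ADD".
params.allow_custom_ops = false;
params.allow_flex_ops = false;

// 2) Operators known to the standard TensorFlow runtime are exported as
//    Flex ops when allowed, e.g. "ResizeNearestNeighbor" ->
//    "custom:FlexResizeNearestNeighbor".
params.allow_flex_ops = true;

// 3) Operators unknown to both are exported as plain custom ops when
//    allowed, e.g. "ResizeNearestNeighbor" -> "custom:ResizeNearestNeighbor".
params.allow_flex_ops = false;
params.allow_custom_ops = true;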
@@ -213,6 +185,101 @@ TEST_F(ExportTest, QuantizeWeights) {
   EXPECT_LT(quantized_result.size(), unquantized_result.size());
 }
 
+class OpSetsTest : public ExportTest {
+ public:
+  enum OpSet { kTfLiteBuiltins, kSelectTfOps, kCustomOps };
+
+  void SetAllowedOpSets(std::initializer_list<OpSet> sets) {
+    import_all_ops_as_unsupported_ = true;
+    params_.allow_custom_ops = false;
+    params_.allow_flex_ops = false;
+    params_.quantize_weights = false;
+
+    for (OpSet i : sets) {
+      switch (i) {
+        case kTfLiteBuiltins:
+          import_all_ops_as_unsupported_ = false;
+          break;
+        case kSelectTfOps:
+          params_.allow_flex_ops = true;
+          break;
+        case kCustomOps:
+          params_.allow_custom_ops = true;
+          break;
+      }
+    }
+  }
+
+  std::vector<string> ImportExport(std::initializer_list<string> op_names) {
+    ResetOperators();
+    if (!import_all_ops_as_unsupported_) {
+      AddOperatorsByName(op_names);
+    } else {
+      for (const string& name : op_names) {
+        auto* op = new TensorFlowUnsupportedOperator;
+        op->tensorflow_op = name;
+        input_model_.operators.emplace_back(op);
+      }
+    }
+    return ExportAndSummarizeOperators(params_);
+  }
+
+ private:
+  bool import_all_ops_as_unsupported_;
+  ExportParams params_;
+};
+
+TEST_F(OpSetsTest, BuiltinsOnly) {
+  // --target_op_set=TFLITE_BUILTINS
+  SetAllowedOpSets({kTfLiteBuiltins});
+  EXPECT_THAT(ImportExport({"Add", "AdjustHue", "UnrollAndFold"}),
+              ElementsAre());
+  EXPECT_THAT(ImportExport({"Add"}), ElementsAre("builtin:ADD"));
+
+  // --target_op_set=TFLITE_BUILTINS --allow_custom_ops
+  SetAllowedOpSets({kTfLiteBuiltins, kCustomOps});
+  EXPECT_THAT(
+      ImportExport({"Add", "AdjustHue", "UnrollAndFold"}),
+      ElementsAre("builtin:ADD", "custom:AdjustHue", "custom:UnrollAndFold"));
+}
+
+TEST_F(OpSetsTest, TfSelectOnly) {
+  // --target_op_set=SELECT_TF_OPS
+  SetAllowedOpSets({kSelectTfOps});
+  EXPECT_THAT(ImportExport({"Add", "AdjustHue", "UnrollAndFold"}),
+              ElementsAre());
+
+  // TODO(b/117845348): Add should be recognized as a flex op, which should be
+  // OK even if we don't specify --allow_custom_ops
+  // EXPECT_THAT(ImportExport({"Add"}), ElementsAre("custom:FlexAdd"));
+  EXPECT_THAT(ImportExport({"Add"}), ElementsAre());
+
+  // --target_op_set=SELECT_TF_OPS --allow_custom_ops
+  SetAllowedOpSets({kSelectTfOps, kCustomOps});
+  EXPECT_THAT(ImportExport({"Add", "AdjustHue", "UnrollAndFold"}),
+              ElementsAre("custom:FlexAdd", "custom:FlexAdjustHue",
+                          "custom:UnrollAndFold"));
+}
+
+TEST_F(OpSetsTest, BuiltinsAndTfSelect) {
+  // --target_op_set=TFLITE_BUILTINS,SELECT_TF_OPS
+  SetAllowedOpSets({kTfLiteBuiltins, kSelectTfOps});
+  EXPECT_THAT(ImportExport({"Add", "AdjustHue", "UnrollAndFold"}),
+              ElementsAre());
+
+  // TODO(b/117845348): AdjustHue should be recognized as a flex op,
+  // which should be OK even if we don't specify --allow_custom_ops
+  // EXPECT_THAT(ImportExport({"Add", "AdjustHue"}),
+  //             ElementsAre("builtin:ADD", "custom:FlexAdjustHue"));
+  EXPECT_THAT(ImportExport({"Add", "AdjustHue"}), ElementsAre());
+
+  // --target_op_set=TFLITE_BUILTINS,SELECT_TF_OPS --allow_custom_ops
+  SetAllowedOpSets({kTfLiteBuiltins, kSelectTfOps, kCustomOps});
+  EXPECT_THAT(ImportExport({"Add", "AdjustHue", "UnrollAndFold"}),
+              ElementsAre("builtin:ADD", "custom:FlexAdjustHue",
+                          "custom:UnrollAndFold"));
+}
+
 // This test is based on a hypothetical scenario in which dilation is
 // supported only in Conv version 2. So Toco populates version=1 when the
 // dilation parameters are all 1, and version=2 otherwise.
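The trailing context comment describes a hypothetical versioning rule for Conv. Its logic amounts to the following minimal sketch (the helper name and parameters are invented for illustration and are not the actual TOCO versioning API):

// Hedged sketch of the rule described in the comment above: all-ones dilation
// stays at Conv version 1, anything else requires the hypothetical version 2.
int HypotheticalConvVersion(int dilation_width_factor,
                            int dilation_height_factor) {
  return (dilation_width_factor == 1 && dilation_height_factor == 1) ? 1 : 2;
}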