Properly unittest all supported conversion modes.
PiperOrigin-RevId: 217799124
This commit is contained in:
parent
79bff06f77
commit
6e45e6ae21
@ -30,6 +30,7 @@ using ::testing::ElementsAre;
|
||||
|
||||
class ExportTest : public ::testing::Test {
|
||||
protected:
|
||||
void ResetOperators() { input_model_.operators.clear(); }
|
||||
void AddTensorsByName(std::initializer_list<string> names) {
|
||||
for (const string& name : names) {
|
||||
input_model_.GetOrCreateArray(name);
|
||||
@ -93,7 +94,11 @@ class ExportTest : public ::testing::Test {
|
||||
std::vector<string> names;
|
||||
|
||||
string result;
|
||||
if (!Export(input_model_, &result, params).ok()) return names;
|
||||
auto status = Export(input_model_, &result, params);
|
||||
if (!status.ok()) {
|
||||
LOG(INFO) << status.error_message();
|
||||
return names;
|
||||
}
|
||||
|
||||
auto* model = ::tflite::GetModel(result.data());
|
||||
|
||||
@ -166,39 +171,6 @@ TEST_F(ExportTest, Export) {
|
||||
EXPECT_THAT(ExportAndGetOperatorIndices(params), ElementsAre(1, 0, 2, 3));
|
||||
}
|
||||
|
||||
TEST_F(ExportTest, ExportWithOptions) {
|
||||
// We have three main types of operators:
|
||||
// 1) first-class TOCO/TF Lite operators, which are exported as TF Lite
|
||||
// builtins
|
||||
// 2) operators that are not recognized by TOCO but are known to the
|
||||
// standard TensorFlow runtime. Those should be exported as FLEX
|
||||
// operators, when requested. Otherwise they are custom TF Lite ops.
|
||||
// 3) operators that are not recognized and should be exported as custom
|
||||
// TF Lite ops. There are two subtypes: ops that would be recognized
|
||||
// by a custom TensorFlow runtime, and "fake" ops that have only a
|
||||
// TF Lite impl.
|
||||
AddOperatorsByName({"Add", "ResizeNearestNeighbor"});
|
||||
|
||||
ExportParams params;
|
||||
params.allow_custom_ops = false;
|
||||
params.allow_flex_ops = false;
|
||||
params.quantize_weights = false;
|
||||
|
||||
// Conversion fails because ResizeNearestNeighbor is unknown.
|
||||
EXPECT_THAT(ExportAndSummarizeOperators(params), ElementsAre());
|
||||
|
||||
// ResizeNearestNeighbor is treated as a simple custom op (#3 above).
|
||||
params.allow_custom_ops = true;
|
||||
EXPECT_THAT(ExportAndSummarizeOperators(params),
|
||||
ElementsAre("builtin:ADD", "custom:ResizeNearestNeighbor"));
|
||||
|
||||
// ResizeNearestNeighbor is recognized as a TensorFlow (Flex) op.
|
||||
params.allow_custom_ops = true;
|
||||
params.allow_flex_ops = true;
|
||||
EXPECT_THAT(ExportAndSummarizeOperators(params),
|
||||
ElementsAre("builtin:ADD", "custom:FlexResizeNearestNeighbor"));
|
||||
}
|
||||
|
||||
TEST_F(ExportTest, QuantizeWeights) {
|
||||
// Sanity check for quantize_weights parameter.
|
||||
BuildQuantizableTestModel();
|
||||
@ -213,6 +185,101 @@ TEST_F(ExportTest, QuantizeWeights) {
|
||||
EXPECT_LT(quantized_result.size(), unquantized_result.size());
|
||||
}
|
||||
|
||||
class OpSetsTest : public ExportTest {
|
||||
public:
|
||||
enum OpSet { kTfLiteBuiltins, kSelectTfOps, kCustomOps };
|
||||
|
||||
void SetAllowedOpSets(std::initializer_list<OpSet> sets) {
|
||||
import_all_ops_as_unsupported_ = true;
|
||||
params_.allow_custom_ops = false;
|
||||
params_.allow_flex_ops = false;
|
||||
params_.quantize_weights = false;
|
||||
|
||||
for (OpSet i : sets) {
|
||||
switch (i) {
|
||||
case kTfLiteBuiltins:
|
||||
import_all_ops_as_unsupported_ = false;
|
||||
break;
|
||||
case kSelectTfOps:
|
||||
params_.allow_flex_ops = true;
|
||||
break;
|
||||
case kCustomOps:
|
||||
params_.allow_custom_ops = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<string> ImportExport(std::initializer_list<string> op_names) {
|
||||
ResetOperators();
|
||||
if (!import_all_ops_as_unsupported_) {
|
||||
AddOperatorsByName(op_names);
|
||||
} else {
|
||||
for (const string& name : op_names) {
|
||||
auto* op = new TensorFlowUnsupportedOperator;
|
||||
op->tensorflow_op = name;
|
||||
input_model_.operators.emplace_back(op);
|
||||
}
|
||||
}
|
||||
return ExportAndSummarizeOperators(params_);
|
||||
}
|
||||
|
||||
private:
|
||||
bool import_all_ops_as_unsupported_;
|
||||
ExportParams params_;
|
||||
};
|
||||
|
||||
TEST_F(OpSetsTest, BuiltinsOnly) {
|
||||
// --target_op_set=TFLITE_BUILTINS
|
||||
SetAllowedOpSets({kTfLiteBuiltins});
|
||||
EXPECT_THAT(ImportExport({"Add", "AdjustHue", "UnrollAndFold"}),
|
||||
ElementsAre());
|
||||
EXPECT_THAT(ImportExport({"Add"}), ElementsAre("builtin:ADD"));
|
||||
|
||||
// --target_op_set=TFLITE_BUILTINS --allow_custom_ops
|
||||
SetAllowedOpSets({kTfLiteBuiltins, kCustomOps});
|
||||
EXPECT_THAT(
|
||||
ImportExport({"Add", "AdjustHue", "UnrollAndFold"}),
|
||||
ElementsAre("builtin:ADD", "custom:AdjustHue", "custom:UnrollAndFold"));
|
||||
}
|
||||
|
||||
TEST_F(OpSetsTest, TfSelectOnly) {
|
||||
// --target_op_set=SELECT_TF_OPS
|
||||
SetAllowedOpSets({kSelectTfOps});
|
||||
EXPECT_THAT(ImportExport({"Add", "AdjustHue", "UnrollAndFold"}),
|
||||
ElementsAre());
|
||||
|
||||
// TODO(b/117845348): Add should be recognized as a flex op, which should be
|
||||
// OK even if we don't specify --allow_custom_ops
|
||||
// EXPECT_THAT(ImportExport({"Add"}), ElementsAre("custom:FlexAdd"));
|
||||
EXPECT_THAT(ImportExport({"Add"}), ElementsAre());
|
||||
|
||||
// --target_op_set=SELECT_TF_OPS --allow_custom_ops
|
||||
SetAllowedOpSets({kSelectTfOps, kCustomOps});
|
||||
EXPECT_THAT(ImportExport({"Add", "AdjustHue", "UnrollAndFold"}),
|
||||
ElementsAre("custom:FlexAdd", "custom:FlexAdjustHue",
|
||||
"custom:UnrollAndFold"));
|
||||
}
|
||||
|
||||
TEST_F(OpSetsTest, BuiltinsAndTfSelect) {
|
||||
// --target_op_set=TFLITE_BUILTINS,SELECT_TF_OPS
|
||||
SetAllowedOpSets({kTfLiteBuiltins, kSelectTfOps});
|
||||
EXPECT_THAT(ImportExport({"Add", "AdjustHue", "UnrollAndFold"}),
|
||||
ElementsAre());
|
||||
|
||||
// TODO(b/117845348): AdjustHue should be recognized as a flex op,
|
||||
// which should be OK even if we don't specify --allow_custom_ops
|
||||
// EXPECT_THAT(ImportExport({"Add", "AdjustHue"}),
|
||||
// ElementsAre("builtin:ADD", "custom:FlexAdjustHue"));
|
||||
EXPECT_THAT(ImportExport({"Add", "AdjustHue"}), ElementsAre());
|
||||
|
||||
// --target_op_set=TFLITE_BUILTINS,SELECT_TF_OPS --allow_custom_ops
|
||||
SetAllowedOpSets({kTfLiteBuiltins, kSelectTfOps, kCustomOps});
|
||||
EXPECT_THAT(ImportExport({"Add", "AdjustHue", "UnrollAndFold"}),
|
||||
ElementsAre("builtin:ADD", "custom:FlexAdjustHue",
|
||||
"custom:UnrollAndFold"));
|
||||
}
|
||||
|
||||
// This test is based on a hypothetical scenario that dilation is supported
|
||||
// only in Conv version 2. So Toco populates version=1 when dialation
|
||||
// parameters are all 1, and version=2 otehrwise.
|
||||
|
Loading…
Reference in New Issue
Block a user