Add support for Argmax in quantize_model.
Also fix error message to say what operation is unsupported.

PiperOrigin-RevId: 243547677
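For context, a minimal sketch of how the quantizer entry point is exercised for a model containing ArgMax, mirroring the new QuantizeArgMaxTest below (the builder/model/error-reporter setup is assumed from the QuantizeModelTest fixture in this diff; this is a fragment, not a full program):

  // Sketch only: names are taken from the test code added by this change.
  flatbuffers::FlatBufferBuilder builder;
  ModelT model;                      // unpacked mutable copy of the float model
  FailOnErrorReporter error_reporter;
  // Request int8 inputs and outputs; per-op behavior comes from
  // operator_property::GetOperatorProperty (see below).
  TfLiteStatus status = QuantizeModel(&builder, &model, TensorType_INT8,
                                      TensorType_INT8, &error_reporter);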
commit a41e83060f
parent 8ddfb1a135
tensorflow/lite/tools/optimize/BUILD
@@ -192,6 +192,7 @@ tf_cc_test(
     ],
     data = [
         "//tensorflow/lite/tools/optimize:testdata/add_with_const_input.bin",
+        "//tensorflow/lite/tools/optimize:testdata/argmax.bin",
         "//tensorflow/lite/tools/optimize:testdata/concat.bin",
         "//tensorflow/lite/tools/optimize:testdata/multi_input_add_reshape.bin",
         "//tensorflow/lite/tools/optimize:testdata/single_avg_pool_min_minus_5_max_plus_5.bin",
tensorflow/lite/tools/optimize/operator_property.cc
@@ -131,6 +131,19 @@ TfLiteStatus GetOperatorProperty(const BuiltinOperator& op,
     property->restricted_value_on_output = {1 / 128.0, 0};
     return kTfLiteOk;
   }
+  if (op == BuiltinOperator_ARG_MAX) {
+    property->per_axis = false;
+    property->per_axis_index = 0;
+    property->arbitrary_inputs = false;
+    property->input_indexes = {0};
+    // ArgMax has no quantizable output, so there is nothing to do here.
+    property->output_indexes = {};
+    property->biases = {};
+    property->restrict_same_input_output_scale = false;
+    property->restriction_on_output = false;
+    property->restricted_value_on_output = {};
+    return kTfLiteOk;
+  }
   return kTfLiteError;
 }

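A hedged sketch of what the new entry means for callers of GetOperatorProperty (namespaces as used in quantize_model.cc below; the assertions simply restate the values set above, and this is a fragment assuming <cassert> and <vector> are available):

  operator_property::OperatorProperty property;
  if (operator_property::GetOperatorProperty(BuiltinOperator_ARG_MAX,
                                             &property) == kTfLiteOk) {
    // Only input 0 (the values tensor) is quantized; the axis input and
    // the index-valued output are left untouched.
    assert(property.input_indexes == std::vector<int>({0}));
    assert(property.output_indexes.empty());
  }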
tensorflow/lite/tools/optimize/quantize_model.cc
@@ -309,8 +309,12 @@ TfLiteStatus QuantizeWeightsInputOutput(flatbuffers::FlatBufferBuilder* builder,
       const BuiltinOperator op_code =
           model->operator_codes[op->opcode_index]->builtin_code;
       operator_property::OperatorProperty property;
-      TF_LITE_ENSURE_STATUS(
-          operator_property::GetOperatorProperty(op_code, &property));
+      if (operator_property::GetOperatorProperty(op_code, &property) ==
+          kTfLiteError) {
+        error_reporter->Report("Quantization not yet supported for op: %s",
+                               EnumNameBuiltinOperator(op_code));
+        return kTfLiteError;
+      }
       // Quantize weight and inputs.
       std::vector<int> input_indexes;
       if (property.arbitrary_inputs) {

@@ -323,7 +327,7 @@ TfLiteStatus QuantizeWeightsInputOutput(flatbuffers::FlatBufferBuilder* builder,
       for (const int input_idx : input_indexes) {
         if (input_idx >= op->inputs.size()) {
           error_reporter->Report(
-              "Requaired input index %d is larger than the input length of op "
+              "Required input index %d is larger than the input length of op "
               "%s at index %d in subgraph %d",
              input_idx, op->inputs.size(), EnumNameBuiltinOperator(op_code),
              op_idx, subgraph_idx);
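For illustration, the diagnostic the new error path emits for an op code with no property entry (EnumNameBuiltinOperator is the flatbuffers-generated name lookup; BuiltinOperator_CUSTOM is just one example of a code GetOperatorProperty rejects):

  // Example: for an unsupported op the quantizer now names the culprit.
  error_reporter->Report("Quantization not yet supported for op: %s",
                         EnumNameBuiltinOperator(BuiltinOperator_CUSTOM));
  // Reported message: "Quantization not yet supported for op: CUSTOM"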
tensorflow/lite/tools/optimize/quantize_model_test.cc
@@ -790,6 +790,43 @@ TEST_F(QuantizeConstInputTest, VerifyConstOpInput) {
   EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);
 }
 
+class QuantizeArgMaxTest : public QuantizeModelTest {
+ protected:
+  QuantizeArgMaxTest() {
+    input_model_ = ReadModel(internal::kModelWithArgMaxOp);
+    readonly_model_ = input_model_->GetModel();
+    readonly_model_->UnPackTo(&model_);
+  }
+};
+
+TEST_F(QuantizeArgMaxTest, VerifyArgMax) {
+  auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
+                              TensorType_INT8, &error_reporter_);
+  ASSERT_EQ(kTfLiteOk, status);
+
+  const auto& subgraph = model_.subgraphs[0];
+  auto op = subgraph->operators[0].get();
+  ASSERT_EQ(model_.operator_codes[op->opcode_index].get()->builtin_code,
+            BuiltinOperator_ARG_MAX);
+
+  ASSERT_EQ(op->inputs.size(), 2);
+  ASSERT_EQ(op->outputs.size(), 1);
+
+  auto float_graph = readonly_model_->subgraphs()->Get(0);
+  // Verify the ArgMax input is quantized.
+  ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
+            TensorType_FLOAT32);
+  EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
+
+  // The ArgMax axis input should keep its original type.
+  ASSERT_EQ(float_graph->tensors()->Get(op->inputs[1])->type(),
+            subgraph->tensors[op->inputs[1]].get()->type);
+
+  // The ArgMax output should keep its original type.
+  ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
+            subgraph->tensors[op->outputs[0]].get()->type);
+}
+
 }  // namespace
 }  // namespace optimize
 }  // namespace tflite
tensorflow/lite/tools/optimize/test_util.cc
@@ -41,6 +41,8 @@ const char* kFloatConcatMax5Max10Max10 = "concat.bin";
 
 const char* kModelWithCustomOp = "custom_op.bin";
 
+const char* kModelWithArgMaxOp = "argmax.bin";
+
 int FailOnErrorReporter::Report(const char* format, va_list args) {
   char buf[1024];
   vsnprintf(buf, sizeof(buf), format, args);
tensorflow/lite/tools/optimize/test_util.h
@@ -63,6 +63,9 @@ extern const char* kFloatConcatMax5Max10Max10;
 // Test model with a custom op.
 extern const char* kModelWithCustomOp;
 
+// Test model with an argmax op.
+extern const char* kModelWithArgMaxOp;
+
 // An error reporter that fails on testing.
 class FailOnErrorReporter : public ErrorReporter {
  public:
BIN  tensorflow/lite/tools/optimize/testdata/argmax.bin  (vendored, new file)
Binary file not shown.