Add support for Argmax in quantize_model.

Also fix error message to say what operation is unsupported.

PiperOrigin-RevId: 243547677
This commit is contained in:
Suharsh Sivakumar 2019-04-14 20:55:49 -07:00 committed by TensorFlower Gardener
parent 8ddfb1a135
commit a41e83060f
7 changed files with 63 additions and 3 deletions

View File

@@ -192,6 +192,7 @@ tf_cc_test(
],
data = [
"//tensorflow/lite/tools/optimize:testdata/add_with_const_input.bin",
"//tensorflow/lite/tools/optimize:testdata/argmax.bin",
"//tensorflow/lite/tools/optimize:testdata/concat.bin",
"//tensorflow/lite/tools/optimize:testdata/multi_input_add_reshape.bin",
"//tensorflow/lite/tools/optimize:testdata/single_avg_pool_min_minus_5_max_plus_5.bin",

View File

@@ -131,6 +131,19 @@ TfLiteStatus GetOperatorProperty(const BuiltinOperator& op,
property->restricted_value_on_output = {1 / 128.0, 0};
return kTfLiteOk;
}
if (op == BuiltinOperator_ARG_MAX) {
property->per_axis = false;
property->per_axis_index = 0;
property->arbitrary_inputs = false;
property->input_indexes = {0};
// ArgMax has no quantizable output, so there is nothing to do here.
property->output_indexes = {};
property->biases = {};
property->restrict_same_input_output_scale = false;
property->restriction_on_output = false;
property->restricted_value_on_output = {};
return kTfLiteOk;
}
return kTfLiteError;
}

View File

@@ -309,8 +309,12 @@ TfLiteStatus QuantizeWeightsInputOutput(flatbuffers::FlatBufferBuilder* builder,
const BuiltinOperator op_code =
model->operator_codes[op->opcode_index]->builtin_code;
operator_property::OperatorProperty property;
TF_LITE_ENSURE_STATUS(
operator_property::GetOperatorProperty(op_code, &property));
if (operator_property::GetOperatorProperty(op_code, &property) ==
kTfLiteError) {
error_reporter->Report("Quantization not yet supported for op: %s",
EnumNameBuiltinOperator(op_code));
return kTfLiteError;
}
// Quantize weight and inputs.
std::vector<int> input_indexes;
if (property.arbitrary_inputs) {
@@ -323,7 +327,7 @@ TfLiteStatus QuantizeWeightsInputOutput(flatbuffers::FlatBufferBuilder* builder,
for (const int input_idx : input_indexes) {
if (input_idx >= op->inputs.size()) {
error_reporter->Report(
"Requaired input index %d is larger than the input length of op "
"Required input index %d is larger than the input length of op "
"%s at index %d in subgraph %d",
input_idx, op->inputs.size(), EnumNameBuiltinOperator(op_code),
op_idx, subgraph_idx);

View File

@@ -790,6 +790,43 @@ TEST_F(QuantizeConstInputTest, VerifyConstOpInput) {
EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);
}
// Fixture for quantization tests on the single-ArgMax-op test model
// (kModelWithArgMaxOp / argmax.bin). Loads the float model once per test
// so each test starts from a pristine, un-quantized graph.
class QuantizeArgMaxTest : public QuantizeModelTest {
protected:
QuantizeArgMaxTest() {
// Members below are inherited from QuantizeModelTest (declared outside
// this view): input_model_ owns the file contents, readonly_model_ is
// the flatbuffer view, model_ is the mutable unpacked copy.
input_model_ = ReadModel(internal::kModelWithArgMaxOp);
readonly_model_ = input_model_->GetModel();
// Unpack into the mutable ModelT so QuantizeModel can rewrite it
// in place while readonly_model_ keeps the original float graph
// for before/after comparisons.
readonly_model_->UnPackTo(&model_);
}
};
// Quantizes the ArgMax model to INT8 and verifies the expected per-tensor
// behavior: the data input is quantized, while the axis input and the
// index output keep their original (non-quantizable) types.
TEST_F(QuantizeArgMaxTest, VerifyArgMax) {
auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
TensorType_INT8, &error_reporter_);
ASSERT_EQ(kTfLiteOk, status);
const auto& subgraph = model_.subgraphs[0];
auto op = subgraph->operators[0].get();
// Sanity-check that the sole operator really is ArgMax with the
// expected arity (input tensor + axis in, indices out).
ASSERT_EQ(model_.operator_codes[op->opcode_index].get()->builtin_code,
BuiltinOperator_ARG_MAX);
ASSERT_EQ(op->inputs.size(), 2);
ASSERT_EQ(op->outputs.size(), 1);
// readonly_model_ still holds the pre-quantization float graph; compare
// tensor types before vs. after quantization.
auto float_graph = readonly_model_->subgraphs()->Get(0);
// Verify ArgMax input is quantized.
ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
TensorType_FLOAT32);
EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
// Verify ArgMax input axis should still be the same type.
ASSERT_EQ(float_graph->tensors()->Get(op->inputs[1])->type(),
subgraph->tensors[op->inputs[1]].get()->type);
// The output of ArgMax should still be the same type.
ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
subgraph->tensors[op->outputs[0]].get()->type);
}
} // namespace
} // namespace optimize
} // namespace tflite

View File

@@ -41,6 +41,8 @@ const char* kFloatConcatMax5Max10Max10 = "concat.bin";
const char* kModelWithCustomOp = "custom_op.bin";
const char* kModelWithArgMaxOp = "argmax.bin";
int FailOnErrorReporter::Report(const char* format, va_list args) {
char buf[1024];
vsnprintf(buf, sizeof(buf), format, args);

View File

@@ -63,6 +63,9 @@ extern const char* kFloatConcatMax5Max10Max10;
// Test model with a custom op.
extern const char* kModelWithCustomOp;
// Test model with an argmax op.
extern const char* kModelWithArgMaxOp;
// An error reporter that fails on testing.
class FailOnErrorReporter : public ErrorReporter {
public:

Binary file not shown.