Make quantization properties per-tensor.

PiperOrigin-RevId: 251876877
This commit is contained in:
Suharsh Sivakumar 2019-06-06 10:24:51 -07:00 committed by TensorFlower Gardener
parent 651201e51a
commit a372bb0e9d
4 changed files with 173 additions and 132 deletions

View File

@@ -21,18 +21,18 @@ OperatorProperty GetOperatorProperty(const BuiltinOperator& op) {
OperatorProperty property; OperatorProperty property;
switch (op) { switch (op) {
case BuiltinOperator_ADD: case BuiltinOperator_ADD:
property.input_indexes = {0, 1}; property.inputs = {{0, {}}, {1, {}}};
property.output_indexes = {0}; property.outputs = {{0, {}}};
property.version = 2; property.version = 2;
break; break;
case BuiltinOperator_ARG_MAX: case BuiltinOperator_ARG_MAX:
property.input_indexes = {0}; property.inputs = {{0, {}}};
// ArgMax has no quantizable output. // ArgMax has no quantizable output.
property.version = 2; property.version = 2;
break; break;
case BuiltinOperator_AVERAGE_POOL_2D: case BuiltinOperator_AVERAGE_POOL_2D:
property.input_indexes = {0}; property.inputs = {{0, {}}};
property.output_indexes = {0}; property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true; property.restrict_same_input_output_scale = true;
property.version = 2; property.version = 2;
break; break;
@@ -40,175 +40,196 @@ OperatorProperty GetOperatorProperty(const BuiltinOperator& op) {
case BuiltinOperator_SPACE_TO_BATCH_ND: case BuiltinOperator_SPACE_TO_BATCH_ND:
case BuiltinOperator_SPACE_TO_DEPTH: case BuiltinOperator_SPACE_TO_DEPTH:
// We skip inputs 1 and 2 since they aren't real valued (they are shapes). // We skip inputs 1 and 2 since they aren't real valued (they are shapes).
property.input_indexes = {0}; property.inputs = {{0, {}}};
property.output_indexes = {0}; property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true; property.restrict_same_input_output_scale = true;
property.version = 2; property.version = 2;
break; break;
case BuiltinOperator_CONCATENATION: case BuiltinOperator_CONCATENATION:
property.arbitrary_inputs = true; property.arbitrary_inputs = true;
property.input_indexes = {}; property.outputs = {{0, {}}};
property.output_indexes = {0};
property.restrict_same_input_output_scale = true; property.restrict_same_input_output_scale = true;
property.version = 2; property.version = 2;
break; break;
case BuiltinOperator_CONV_2D: case BuiltinOperator_CONV_2D: {
property.per_axis = true; TensorProperty tensor_property;
property.per_axis_index = 0; tensor_property.per_axis = true;
property.input_indexes = {0, 1}; tensor_property.per_axis_index = 0;
property.output_indexes = {0}; tensor_property.symmetric = true;
property.inputs = {{0, {}}, {1, tensor_property}};
property.outputs = {{0, {}}};
property.biases = {2}; property.biases = {2};
property.version = 2; property.version = 2;
break; break;
case BuiltinOperator_DEPTHWISE_CONV_2D: }
property.per_axis = true; case BuiltinOperator_DEPTHWISE_CONV_2D: {
property.per_axis_index = 3; TensorProperty tensor_property;
property.input_indexes = {0, 1}; tensor_property.per_axis = true;
property.output_indexes = {0}; tensor_property.per_axis_index = 3;
tensor_property.symmetric = true;
property.inputs = {
{0, {}},
{1, tensor_property},
};
property.outputs = {{0, {}}};
property.biases = {2}; property.biases = {2};
property.version = 3; property.version = 3;
break; break;
}
case BuiltinOperator_EQUAL: case BuiltinOperator_EQUAL:
case BuiltinOperator_NOT_EQUAL: case BuiltinOperator_NOT_EQUAL:
case BuiltinOperator_GREATER: case BuiltinOperator_GREATER:
case BuiltinOperator_GREATER_EQUAL: case BuiltinOperator_GREATER_EQUAL:
case BuiltinOperator_LESS: case BuiltinOperator_LESS:
case BuiltinOperator_LESS_EQUAL: case BuiltinOperator_LESS_EQUAL:
property.input_indexes = {0, 1}; property.inputs = {{0, {}}, {1, {}}};
// Comparisons have no quantizable outputs. // Comparisons have no quantizable outputs.
property.version = 2; property.version = 2;
break; break;
case BuiltinOperator_FULLY_CONNECTED: case BuiltinOperator_FULLY_CONNECTED: {
property.input_indexes = {0, 1}; TensorProperty tensor_property;
property.output_indexes = {0}; tensor_property.symmetric = true;
property.inputs = {{0, {}}, {1, tensor_property}};
property.outputs = {{0, {}}};
property.biases = {2}; property.biases = {2};
property.version = 4; property.version = 4;
break; break;
}
case BuiltinOperator_GATHER: case BuiltinOperator_GATHER:
property.input_indexes = {0}; property.inputs = {{0, {}}};
property.output_indexes = {0}; property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true; property.restrict_same_input_output_scale = true;
property.version = 2; property.version = 2;
break; break;
case BuiltinOperator_LOG_SOFTMAX: case BuiltinOperator_LOG_SOFTMAX: {
property.input_indexes = {0}; property.inputs = {{0, {}}};
property.output_indexes = {0};
// LogSoftmax requires output with 16/256 as scale and 127 as zero point. // LogSoftmax requires output with 16/256 as scale and 127 as zero point.
property.restriction_on_output = true; TensorProperty tensor_property;
property.restricted_value_on_output = {16.0 / 256.0, 127}; tensor_property.restriction = true;
tensor_property.restricted_value = {16.0 / 256.0, 127};
property.outputs = {{0, tensor_property}};
property.version = 2; property.version = 2;
break; break;
case BuiltinOperator_LOGISTIC: }
property.input_indexes = {0}; case BuiltinOperator_LOGISTIC: {
property.output_indexes = {0}; property.inputs = {{0, {}}};
// Logistic requires output with 1/256 as scale and -128 as zero point. // Logistic requires output with 1/256 as scale and -128 as zero point.
property.restriction_on_output = true; TensorProperty tensor_property;
property.restricted_value_on_output = {1 / 256.0, -128}; tensor_property.restriction = true;
tensor_property.restricted_value = {1 / 256.0, -128};
property.outputs = {{0, tensor_property}};
property.version = 2; property.version = 2;
break; break;
case BuiltinOperator_L2_NORMALIZATION: }
property.input_indexes = {0}; case BuiltinOperator_L2_NORMALIZATION: {
property.output_indexes = {0}; property.inputs = {{0, {}}};
// L2 Norm requires output with 1/128 as scale and 0 as zero point. // L2 Norm requires output with 1/128 as scale and 0 as zero point.
property.restriction_on_output = true; TensorProperty tensor_property;
property.restricted_value_on_output = {1 / 128.0, 0}; tensor_property.restriction = true;
tensor_property.restricted_value = {1 / 128.0, 0};
property.outputs = {{0, tensor_property}};
property.version = 2; property.version = 2;
break; break;
}
case BuiltinOperator_MAX_POOL_2D: case BuiltinOperator_MAX_POOL_2D:
property.input_indexes = {0}; property.inputs = {{0, {}}};
property.output_indexes = {0}; property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true; property.restrict_same_input_output_scale = true;
property.version = 2; property.version = 2;
break; break;
case BuiltinOperator_MAXIMUM: case BuiltinOperator_MAXIMUM:
property.input_indexes = {0}; property.inputs = {{0, {}}};
property.output_indexes = {0}; property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true; property.restrict_same_input_output_scale = true;
property.version = 2; property.version = 2;
break; break;
case BuiltinOperator_MEAN: case BuiltinOperator_MEAN:
property.input_indexes = {0}; property.inputs = {{0, {}}};
property.output_indexes = {0}; property.outputs = {{0, {}}};
property.version = 2; property.version = 2;
break; break;
case BuiltinOperator_MINIMUM: case BuiltinOperator_MINIMUM:
property.input_indexes = {0}; property.inputs = {{0, {}}};
property.output_indexes = {0}; property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true; property.restrict_same_input_output_scale = true;
property.version = 2; property.version = 2;
break; break;
case BuiltinOperator_MUL: case BuiltinOperator_MUL:
property.input_indexes = {0, 1}; property.inputs = {{0, {}}};
property.output_indexes = {0}; property.outputs = {{0, {}}};
property.version = 2; property.version = 2;
break; break;
case BuiltinOperator_PAD: case BuiltinOperator_PAD:
case BuiltinOperator_PADV2: case BuiltinOperator_PADV2:
property.input_indexes = {0}; property.inputs = {{0, {}}};
property.output_indexes = {0}; property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true; property.restrict_same_input_output_scale = true;
property.version = 2; property.version = 2;
break; break;
case BuiltinOperator_QUANTIZE: case BuiltinOperator_QUANTIZE:
property.input_indexes = {0}; property.inputs = {{0, {}}};
property.output_indexes = {0}; property.outputs = {{0, {}}};
property.version = 1; property.version = 1;
break; break;
case BuiltinOperator_RESHAPE: case BuiltinOperator_RESHAPE:
property.input_indexes = {0}; property.inputs = {{0, {}}};
property.output_indexes = {0}; property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true; property.restrict_same_input_output_scale = true;
property.version = 1; property.version = 1;
break; break;
case BuiltinOperator_RESIZE_BILINEAR: case BuiltinOperator_RESIZE_BILINEAR:
property.input_indexes = {0}; property.inputs = {{0, {}}};
property.output_indexes = {0}; property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true; property.restrict_same_input_output_scale = true;
property.version = 2; property.version = 2;
break; break;
case BuiltinOperator_SHAPE: case BuiltinOperator_SHAPE:
property.input_indexes = {0}; property.inputs = {{0, {}}};
// Shape has no quantizable output. // Shape has no quantizable output.
property.version = 1; property.version = 1;
break; break;
case BuiltinOperator_SLICE: case BuiltinOperator_SLICE:
// We skip inputs 1 and 2 since they aren't real valued (they are the // We skip inputs 1 and 2 since they aren't real valued (they are the
// index and size). // index and size).
property.input_indexes = {0}; property.inputs = {{0, {}}};
property.output_indexes = {0}; property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true; property.restrict_same_input_output_scale = true;
property.version = 2; property.version = 2;
break; break;
case BuiltinOperator_SQUEEZE: case BuiltinOperator_SQUEEZE:
property.input_indexes = {0}; property.inputs = {{0, {}}};
property.output_indexes = {0}; property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true; property.restrict_same_input_output_scale = true;
property.version = 1; property.version = 1;
break; break;
case BuiltinOperator_SOFTMAX: case BuiltinOperator_SOFTMAX: {
property.input_indexes = {0}; property.inputs = {{0, {}}};
property.output_indexes = {0};
// Softmax requires output with 1/256 as scale and -128 as zero point. // Softmax requires output with 1/256 as scale and -128 as zero point.
property.restriction_on_output = true; TensorProperty tensor_property;
property.restricted_value_on_output = {1 / 256.0, -128}; tensor_property.restriction = true;
tensor_property.restricted_value = {1 / 256.0, -128};
property.outputs = {{0, tensor_property}};
property.version = 2; property.version = 2;
break; break;
}
case BuiltinOperator_SUB: case BuiltinOperator_SUB:
property.input_indexes = {0, 1}; property.inputs = {{0, {}}, {1, {}}};
property.output_indexes = {0}; property.outputs = {{0, {}}};
property.version = 2; property.version = 2;
break; break;
case BuiltinOperator_TANH: case BuiltinOperator_TANH: {
property.input_indexes = {0}; property.inputs = {{0, {}}};
property.output_indexes = {0};
// Tanh requires output with 1/128 as scale and 0 as zero point. // Tanh requires output with 1/128 as scale and 0 as zero point.
property.restriction_on_output = true; TensorProperty tensor_property;
property.restricted_value_on_output = {1 / 128.0, 0}; tensor_property.restriction = true;
tensor_property.restricted_value = {1 / 128.0, 0};
property.outputs = {{0, tensor_property}};
property.version = 2; property.version = 2;
break; break;
}
case BuiltinOperator_TRANSPOSE: case BuiltinOperator_TRANSPOSE:
property.input_indexes = {0}; property.inputs = {{0, {}}};
property.output_indexes = {0}; property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true; property.restrict_same_input_output_scale = true;
property.version = 2; property.version = 2;
break; break;

View File

@@ -22,30 +22,34 @@ namespace tflite {
namespace optimize { namespace optimize {
namespace operator_property { namespace operator_property {
struct OperatorProperty { struct TensorProperty {
// Is a quantized operations currently supported. // per_axis also implies symmetric currently.
bool quantizable = true;
// Per axis.
bool per_axis = false; bool per_axis = false;
// TODO(jianlijianli): remove dimension index and read it from tensor instead. // TODO(jianlijianli): remove dimension index and read it from tensor instead.
int per_axis_index = 0; int per_axis_index = 0;
bool symmetric = false;
// Constraints.
bool restriction = false;
// scale/zero_point hardcoded.
std::pair<float, int> restricted_value = {0.0, 0};
};
struct OperatorProperty {
// Is a quantized operations currently supported.
bool quantizable = true;
// Op has arbitrary number of inputs, such as concat. // Op has arbitrary number of inputs, such as concat.
bool arbitrary_inputs = false; bool arbitrary_inputs = false;
// Input and weight indexes. Unable to separate the two because of ops such as // Input indexes -> input tensor property.
// ADD. std::vector<std::pair<int, TensorProperty>> inputs = {};
std::vector<int> input_indexes = {}; // Output indexes -> output tensor property.
std::vector<std::pair<int, TensorProperty>> outputs = {};
// Output indexes
std::vector<int> output_indexes = {};
// Bias indexes. // Bias indexes.
std::vector<int> biases = {}; std::vector<int> biases = {};
// Constraints. // Constraints.
bool restrict_same_input_output_scale = false; bool restrict_same_input_output_scale = false;
bool restriction_on_output = false;
std::pair<float, float> restricted_value_on_output = {0.0, 0.0};
// Op version. // Op version.
int version = 1; int version = 1;

View File

@@ -357,6 +357,9 @@ TfLiteStatus SymmetricPerChannelBiasQuantize(ModelT* model, TensorT* tensor,
TfLiteStatus QuantizeWeight(ModelT* model, TensorT* tensor, bool per_channel, TfLiteStatus QuantizeWeight(ModelT* model, TensorT* tensor, bool per_channel,
int per_axis_index) { int per_axis_index) {
// TODO(suharshs): Currently we conflate quantizing weights and constants. Its
// possible that the right thing to do is asymmetric quantize the weight. Add
// support for this.
if (per_channel) { if (per_channel) {
return SymmetricQuantizeTensorPerChannel(model, tensor, per_axis_index); return SymmetricQuantizeTensorPerChannel(model, tensor, per_axis_index);
} else { } else {

View File

@@ -276,12 +276,12 @@ TfLiteStatus ApplyConstraints(ModelT* model, ErrorReporter* error_reporter) {
} }
// Basically only Concat passes this check. // Basically only Concat passes this check.
if (!property.restrict_same_input_output_scale || if (!property.restrict_same_input_output_scale ||
(property.input_indexes.size() == 1 && (property.inputs.size() == 1 && property.outputs.size() == 1 &&
property.output_indexes.size() == 1 && property.biases.empty())) { property.biases.empty())) {
continue; continue;
} }
// If ApplyConstraintsnd requant is needed, use the min of min and max of // If ApplyConstraints and requant is needed, use the min of min and max
// max, which means using the scale and zero point of output. // of max, which means using the scale and zero point of output.
TensorT* output_tensor = subgraph->tensors[op->outputs[0]].get(); TensorT* output_tensor = subgraph->tensors[op->outputs[0]].get();
if (!utils::QuantizationParametersExist(output_tensor)) { if (!utils::QuantizationParametersExist(output_tensor)) {
error_reporter->Report( error_reporter->Report(
@@ -332,24 +332,23 @@ TfLiteStatus ApplyConstraints(ModelT* model, ErrorReporter* error_reporter) {
return kTfLiteOk; return kTfLiteOk;
} }
std::vector<int> GetInputIndexes(const OperatorT* op, std::vector<std::pair<int, operator_property::TensorProperty>> GetInputs(
operator_property::OperatorProperty property) { const OperatorT* op, operator_property::OperatorProperty property) {
std::vector<int> input_indexes; std::vector<std::pair<int, operator_property::TensorProperty>> inputs;
if (property.arbitrary_inputs || !property.quantizable) { if (property.arbitrary_inputs || !property.quantizable) {
for (int i = 0; i < op->inputs.size(); ++i) { for (int i = 0; i < op->inputs.size(); ++i) {
input_indexes.push_back(i); inputs.push_back({i, {}});
} }
} else { } else {
input_indexes = property.input_indexes; inputs = property.inputs;
} }
return input_indexes; return inputs;
} }
bool ShouldRestrictSameInputOutputScale( bool ShouldRestrictSameInputOutputScale(
operator_property::OperatorProperty property) { operator_property::OperatorProperty property) {
return (property.input_indexes.size() == 1 && return (property.inputs.size() == 1 && property.outputs.size() == 1 &&
property.output_indexes.size() == 1 && property.biases.empty() && property.biases.empty() && property.restrict_same_input_output_scale);
property.restrict_same_input_output_scale);
} }
bool IsSubgraphInput(SubGraphT* subgraph, int32_t index) { bool IsSubgraphInput(SubGraphT* subgraph, int32_t index) {
@@ -362,10 +361,13 @@ bool IsSubgraphInput(SubGraphT* subgraph, int32_t index) {
} }
// Quantize the op input. Will increment op_idx if ops are added. // Quantize the op input. Will increment op_idx if ops are added.
TfLiteStatus QuantizeOpInput(ModelT* model, int32_t subgraph_idx, TfLiteStatus QuantizeOpInput(
size_t* op_idx, ModelT* model, int32_t subgraph_idx, size_t* op_idx,
operator_property::OperatorProperty property, operator_property::OperatorProperty property,
int32_t input_idx, ErrorReporter* error_reporter) { const std::pair<int32_t, operator_property::TensorProperty>& input,
ErrorReporter* error_reporter) {
int32_t input_idx = input.first;
operator_property::TensorProperty tensor_property = input.second;
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get(); SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
OperatorT* op = subgraph->operators[*op_idx].get(); OperatorT* op = subgraph->operators[*op_idx].get();
const BuiltinOperator op_code = const BuiltinOperator op_code =
@@ -384,8 +386,11 @@ TfLiteStatus QuantizeOpInput(ModelT* model, int32_t subgraph_idx,
if (property.quantizable && !is_input_quantized) { if (property.quantizable && !is_input_quantized) {
// The operation is quantizable, but the input isn't yet quantized. // The operation is quantizable, but the input isn't yet quantized.
if (utils::HasBuffer(model, subgraph, tensor_idx)) { if (utils::HasBuffer(model, subgraph, tensor_idx)) {
if (utils::QuantizeWeight(model, tensor, property.per_axis, // TODO(suharshs): Look at consumers, throw error if one consumer is
property.per_axis_index) == kTfLiteError) { // per-channel and one per-layer.
if (utils::QuantizeWeight(model, tensor, tensor_property.per_axis,
tensor_property.per_axis_index) ==
kTfLiteError) {
error_reporter->Report( error_reporter->Report(
"Unable to quantize buffer or min/max value for input %d " "Unable to quantize buffer or min/max value for input %d "
"in op %s in subgraph %d, node: %d", "in op %s in subgraph %d, node: %d",
@@ -393,6 +398,7 @@ TfLiteStatus QuantizeOpInput(ModelT* model, int32_t subgraph_idx,
return kTfLiteError; return kTfLiteError;
} }
} else if (utils::HasMinMax(tensor)) { } else if (utils::HasMinMax(tensor)) {
// TODO(suharshs): Handle per-channel dynamic tensor.
if (IsSubgraphInput(subgraph, tensor_idx)) { if (IsSubgraphInput(subgraph, tensor_idx)) {
utils::QuantizeActivation(tensor); utils::QuantizeActivation(tensor);
} else { } else {
@@ -442,11 +448,13 @@ TfLiteStatus QuantizeOpInput(ModelT* model, int32_t subgraph_idx,
} }
// Quantize the op output. // Quantize the op output.
TfLiteStatus QuantizeOpOutput(ModelT* model, int32_t subgraph_idx, TfLiteStatus QuantizeOpOutput(
int32_t op_idx, ModelT* model, int32_t subgraph_idx, int32_t op_idx,
operator_property::OperatorProperty property, operator_property::OperatorProperty property,
int32_t output_idx, const std::pair<int32_t, operator_property::TensorProperty>& output,
ErrorReporter* error_reporter) { ErrorReporter* error_reporter) {
int32_t output_idx = output.first;
operator_property::TensorProperty tensor_property = output.second;
// If the operator is not quantizable, we don't need to do anything for the // If the operator is not quantizable, we don't need to do anything for the
// output. // output.
if (!property.quantizable) { if (!property.quantizable) {
@@ -470,16 +478,16 @@ TfLiteStatus QuantizeOpOutput(ModelT* model, int32_t subgraph_idx,
// Copy quantization parameter. For average pool, max pool, etc // Copy quantization parameter. For average pool, max pool, etc
// min/max can be different but we want them to be the same. // min/max can be different but we want them to be the same.
// Get scale and zero point of input. // Get scale and zero point of input.
if (property.input_indexes[0] >= op->inputs.size()) { if (property.inputs[0].first >= op->inputs.size()) {
error_reporter->Report( error_reporter->Report(
"Required input index %d is larger than the input length of " "Required input index %d is larger than the input length of "
"op %s at index %d in subgraph %d", "op %s at index %d in subgraph %d",
property.input_indexes[0], op->inputs.size(), property.inputs[0].first, op->inputs.size(),
EnumNameBuiltinOperator(op_code), op_idx, subgraph_idx); EnumNameBuiltinOperator(op_code), op_idx, subgraph_idx);
return kTfLiteError; return kTfLiteError;
} }
const int input_index = op->inputs[property.input_indexes[0]]; const int input_tensor_idx = op->inputs[property.inputs[0].first];
TensorT* input_tensor = subgraph->tensors[input_index].get(); TensorT* input_tensor = subgraph->tensors[input_tensor_idx].get();
if (input_tensor->quantization->scale.size() != 1 || if (input_tensor->quantization->scale.size() != 1 ||
input_tensor->quantization->zero_point.size() != 1 || input_tensor->quantization->zero_point.size() != 1 ||
input_tensor->quantization->min.size() != 1 || input_tensor->quantization->min.size() != 1 ||
@@ -514,8 +522,8 @@ TfLiteStatus QuantizeOpOutput(ModelT* model, int32_t subgraph_idx,
output_tensor->quantization->min.push_back(min); output_tensor->quantization->min.push_back(min);
output_tensor->quantization->max.push_back(max); output_tensor->quantization->max.push_back(max);
output_tensor->type = TensorType_INT8; output_tensor->type = TensorType_INT8;
} else if (property.restriction_on_output) { } else if (tensor_property.restriction) {
const auto scale_and_zp = property.restricted_value_on_output; const auto scale_and_zp = tensor_property.restricted_value;
// Apply to output. // Apply to output.
output_tensor->quantization = absl::make_unique<QuantizationParametersT>(); output_tensor->quantization = absl::make_unique<QuantizationParametersT>();
output_tensor->quantization->scale.push_back(scale_and_zp.first); output_tensor->quantization->scale.push_back(scale_and_zp.first);
@@ -557,15 +565,17 @@ TfLiteStatus QuantizeWeightsInputOutput(ModelT* model, bool allow_float,
} }
// Quantize operator inputs/weights. // Quantize operator inputs/weights.
for (const int input_idx : GetInputIndexes(op, property)) { for (const std::pair<int, operator_property::TensorProperty>& input :
TF_LITE_ENSURE_STATUS(QuantizeOpInput( GetInputs(op, property)) {
model, subgraph_idx, &op_idx, property, input_idx, error_reporter)); TF_LITE_ENSURE_STATUS(QuantizeOpInput(model, subgraph_idx, &op_idx,
property, input, error_reporter));
} }
// Quantize operator outputs. // Quantize operator outputs.
for (const int output_idx : property.output_indexes) { for (const std::pair<int, operator_property::TensorProperty>& output :
property.outputs) {
TF_LITE_ENSURE_STATUS(QuantizeOpOutput( TF_LITE_ENSURE_STATUS(QuantizeOpOutput(
model, subgraph_idx, op_idx, property, output_idx, error_reporter)); model, subgraph_idx, op_idx, property, output, error_reporter));
} }
} }
} }
@@ -601,7 +611,7 @@ TfLiteStatus QuantizeBiases(ModelT* model, ErrorReporter* error_reporter) {
if (utils::HasBuffer(model, subgraph, op->inputs[bias_idx])) { if (utils::HasBuffer(model, subgraph, op->inputs[bias_idx])) {
TensorT* bias_tensor = TensorT* bias_tensor =
subgraph->tensors[op->inputs[bias_idx]].get(); subgraph->tensors[op->inputs[bias_idx]].get();
if (property.input_indexes.size() != 2) { if (property.inputs.size() != 2) {
error_reporter->Report( error_reporter->Report(
"Expect the input length of " "Expect the input length of "
"op %s at index %d in subgraph %d to be 2", "op %s at index %d in subgraph %d to be 2",
@@ -610,12 +620,15 @@ TfLiteStatus QuantizeBiases(ModelT* model, ErrorReporter* error_reporter) {
return kTfLiteError; return kTfLiteError;
} }
TensorT* input_tensor = TensorT* input_tensor =
subgraph->tensors[op->inputs[property.input_indexes[0]]].get(); subgraph->tensors[op->inputs[property.inputs[0].first]].get();
TensorT* weight_tensor = TensorT* weight_tensor =
subgraph->tensors[op->inputs[property.input_indexes[1]]].get(); subgraph->tensors[op->inputs[property.inputs[1].first]].get();
TF_LITE_ENSURE_STATUS(QuantizeBias( operator_property::TensorProperty weight_property =
model, input_tensor, weight_tensor, bias_tensor, property.inputs[1].second;
property.per_axis, property.per_axis_index, error_reporter)); TF_LITE_ENSURE_STATUS(
QuantizeBias(model, input_tensor, weight_tensor, bias_tensor,
weight_property.per_axis,
weight_property.per_axis_index, error_reporter));
} }
} }
} }