Allow easier re-running of toco on already quantized, mixed-bit-depth
tflite models, by: - fixing a bug whereby PropagateArrayDataTypes was brutally propagating the data type of its first input onto all outputs; - not letting QUANTIZED_UINT8 override existing non-8bit quantized data types. PiperOrigin-RevId: 229858747
This commit is contained in:
parent
a905c626ca
commit
c10447d959
@ -147,12 +147,26 @@ bool MatchOperatorInputs(const Operator& op, const Model& model,
|
||||
if (final_output_mul->type != OperatorType::kMul) {
|
||||
return ::tensorflow::Status::OK();
|
||||
}
|
||||
// final_output_mul->outputs[0] would be one of the two outputs of our
|
||||
// LstmCell. Exit if it does not already have a data type.
|
||||
// We won't be able to propagate data types through a fused LstmCell.
|
||||
if (model->GetArray(final_output_mul->outputs[0]).data_type ==
|
||||
ArrayDataType::kNone) {
|
||||
return ::tensorflow::Status::OK();
|
||||
}
|
||||
Operator *state_output_tanh, *fc_output_sig;
|
||||
if (!MatchOperatorInputs(*final_output_mul, *model, OperatorType::kTanh,
|
||||
&state_output_tanh, OperatorType::kLogistic,
|
||||
&fc_output_sig)) {
|
||||
return ::tensorflow::Status::OK();
|
||||
}
|
||||
// state_output_tanh->inputs[0] would be one of the two outputs of our
|
||||
// LstmCell. Exit if it does not already have a data type.
|
||||
// We won't be able to propagate data types through a fused LstmCell.
|
||||
if (model->GetArray(state_output_tanh->inputs[0]).data_type ==
|
||||
ArrayDataType::kNone) {
|
||||
return ::tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
// State output TanH
|
||||
// (We don't count an operator as ID'd until we verify it has the correct
|
||||
@ -262,11 +276,15 @@ bool MatchOperatorInputs(const Operator& op, const Model& model,
|
||||
lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT]));
|
||||
const string& concat_temp_array_name =
|
||||
AvailableArrayName(*model, base_name + "concat_temp");
|
||||
model->GetOrCreateArray(concat_temp_array_name);
|
||||
auto& concat_temp_array = model->GetOrCreateArray(concat_temp_array_name);
|
||||
concat_temp_array.data_type =
|
||||
model->GetArray(concat_inputs->outputs[0]).data_type;
|
||||
lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = concat_temp_array_name;
|
||||
const string& activ_temp_array_name =
|
||||
AvailableArrayName(*model, base_name + "activ_temp");
|
||||
model->GetOrCreateArray(activ_temp_array_name);
|
||||
auto& activ_temp_array = model->GetOrCreateArray(activ_temp_array_name);
|
||||
activ_temp_array.data_type =
|
||||
model->GetArray(fully_connected->outputs[0]).data_type;
|
||||
lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] = activ_temp_array_name;
|
||||
AddMessageF("Created temp outputs %s and %s on operator %s",
|
||||
concat_temp_array_name, activ_temp_array_name,
|
||||
|
@ -266,6 +266,14 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op,
|
||||
model->GetArray(op->outputs[1]).data_type = unique_op->idx_out_type;
|
||||
break;
|
||||
}
|
||||
case OperatorType::kLstmCell: {
|
||||
// It's tricky to propagate data types through a LstmCell, as that has
|
||||
// multiple inputs and outputs, and there are quantized cases with
|
||||
// mixed (8bit vs 16bit) cases. Fortunately, that should never be needed,
|
||||
// as the data formats, such as TFLITE, that have LstmCell nodes, also
|
||||
// have data type fields for all their arrays.
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
// These operators produce outputs with the same type as their 1st input
|
||||
CHECK_GT(op->inputs.size(), 0);
|
||||
|
@ -178,6 +178,23 @@ void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) {
|
||||
// Ignore non-real data types.
|
||||
continue;
|
||||
}
|
||||
// The enum value QUANTIZED_UINT8 for --inference_type and
|
||||
// --inference_input_type has long meant just 'QUANTIZED', being used as
|
||||
// well in mixed 8-bit / 16-bit quantized models. However,
|
||||
// ConvertIODataTypeToArrayDataType still interpretes it as meaning 8bit,
|
||||
// and people have run into issues in the situation where they have an
|
||||
// already mixed 8-bit / 16-bit quantized model in TFLITE format and
|
||||
// want to run it again through toco, without having to re-specify all the
|
||||
// extra array info that was used in the (complicated) process of initially
|
||||
// quantizing that model. In order to have --inference_type=QUANTIZED_UINT8
|
||||
// just work in that case, we implement the logic that when an array is
|
||||
// already quantized, if --inference_type is quantized (so we're not
|
||||
// asking to dequantize here), no change of quantized data type is to be
|
||||
// recorded.
|
||||
if (array->data_type != toco::ArrayDataType::kFloat &&
|
||||
type != toco::ArrayDataType::kFloat) {
|
||||
continue;
|
||||
}
|
||||
|
||||
array->final_data_type = type;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user