Allow easier re-running of toco on already quantized, mixed-bit-depth

tflite models, by:
  - fixing a bug whereby PropagateArrayDataTypes was brutally propagating
    the data type of its first input onto all outputs;
  - not letting QUANTIZED_UINT8 override existing non-8bit quantized data
    types.

PiperOrigin-RevId: 229858747
This commit is contained in:
A. Unique TensorFlower 2019-01-17 19:01:05 -08:00 committed by TensorFlower Gardener
parent a905c626ca
commit c10447d959
3 changed files with 45 additions and 2 deletions

View File

@ -147,12 +147,26 @@ bool MatchOperatorInputs(const Operator& op, const Model& model,
if (final_output_mul->type != OperatorType::kMul) {
return ::tensorflow::Status::OK();
}
// final_output_mul->outputs[0] would be one of the two outputs of our
// LstmCell. Exit if it does not already have a data type.
// We won't be able to propagate data types through a fused LstmCell.
if (model->GetArray(final_output_mul->outputs[0]).data_type ==
ArrayDataType::kNone) {
return ::tensorflow::Status::OK();
}
Operator *state_output_tanh, *fc_output_sig;
if (!MatchOperatorInputs(*final_output_mul, *model, OperatorType::kTanh,
&state_output_tanh, OperatorType::kLogistic,
&fc_output_sig)) {
return ::tensorflow::Status::OK();
}
// state_output_tanh->inputs[0] would be one of the two outputs of our
// LstmCell. Exit if it does not already have a data type.
// We won't be able to propagate data types through a fused LstmCell.
if (model->GetArray(state_output_tanh->inputs[0]).data_type ==
ArrayDataType::kNone) {
return ::tensorflow::Status::OK();
}
// State output TanH
// (We don't count an operator as ID'd until we verify it has the correct
@ -262,11 +276,15 @@ bool MatchOperatorInputs(const Operator& op, const Model& model,
lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT]));
const string& concat_temp_array_name =
AvailableArrayName(*model, base_name + "concat_temp");
model->GetOrCreateArray(concat_temp_array_name);
auto& concat_temp_array = model->GetOrCreateArray(concat_temp_array_name);
concat_temp_array.data_type =
model->GetArray(concat_inputs->outputs[0]).data_type;
lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = concat_temp_array_name;
const string& activ_temp_array_name =
AvailableArrayName(*model, base_name + "activ_temp");
model->GetOrCreateArray(activ_temp_array_name);
auto& activ_temp_array = model->GetOrCreateArray(activ_temp_array_name);
activ_temp_array.data_type =
model->GetArray(fully_connected->outputs[0]).data_type;
lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] = activ_temp_array_name;
AddMessageF("Created temp outputs %s and %s on operator %s",
concat_temp_array_name, activ_temp_array_name,

View File

@ -266,6 +266,14 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op,
model->GetArray(op->outputs[1]).data_type = unique_op->idx_out_type;
break;
}
case OperatorType::kLstmCell: {
// It's tricky to propagate data types through a LstmCell, as that has
// multiple inputs and outputs, and there are quantized cases with
// mixed (8bit vs 16bit) bit depths. Fortunately, that should never be needed,
// as the data formats, such as TFLITE, that have LstmCell nodes, also
// have data type fields for all their arrays.
break;
}
default: {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);

View File

@ -178,6 +178,23 @@ void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) {
// Ignore non-real data types.
continue;
}
// The enum value QUANTIZED_UINT8 for --inference_type and
// --inference_input_type has long meant just 'QUANTIZED', being used as
// well in mixed 8-bit / 16-bit quantized models. However,
// ConvertIODataTypeToArrayDataType still interprets it as meaning 8bit,
// and people have run into issues in the situation where they have an
// already mixed 8-bit / 16-bit quantized model in TFLITE format and
// want to run it again through toco, without having to re-specify all the
// extra array info that was used in the (complicated) process of initially
// quantizing that model. In order to have --inference_type=QUANTIZED_UINT8
// just work in that case, we implement the logic that when an array is
// already quantized, if --inference_type is quantized (so we're not
// asking to dequantize here), no change of quantized data type is to be
// recorded.
if (array->data_type != toco::ArrayDataType::kFloat &&
type != toco::ArrayDataType::kFloat) {
continue;
}
array->final_data_type = type;
}