diff --git a/tensorflow/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/lite/toco/graph_transformations/identify_lstm.cc index 089ecee959a..65dbb8a1766 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_lstm.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_lstm.cc @@ -147,12 +147,26 @@ bool MatchOperatorInputs(const Operator& op, const Model& model, if (final_output_mul->type != OperatorType::kMul) { return ::tensorflow::Status::OK(); } + // final_output_mul->outputs[0] would be one of the two outputs of our + // LstmCell. Exit if it does not already have a data type. + // We won't be able to propagate data types through a fused LstmCell. + if (model->GetArray(final_output_mul->outputs[0]).data_type == + ArrayDataType::kNone) { + return ::tensorflow::Status::OK(); + } Operator *state_output_tanh, *fc_output_sig; if (!MatchOperatorInputs(*final_output_mul, *model, OperatorType::kTanh, &state_output_tanh, OperatorType::kLogistic, &fc_output_sig)) { return ::tensorflow::Status::OK(); } + // state_output_tanh->inputs[0] would be one of the two outputs of our + // LstmCell. Exit if it does not already have a data type. + // We won't be able to propagate data types through a fused LstmCell. + if (model->GetArray(state_output_tanh->inputs[0]).data_type == + ArrayDataType::kNone) { + return ::tensorflow::Status::OK(); + } // State output TanH // (We don't count an operator as ID'd until we verify it has the correct @@ -262,11 +276,15 @@ bool MatchOperatorInputs(const Operator& op, const Model& model, lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT])); const string& concat_temp_array_name = AvailableArrayName(*model, base_name + "concat_temp"); - model->GetOrCreateArray(concat_temp_array_name); + auto& concat_temp_array = model->GetOrCreateArray(concat_temp_array_name); + concat_temp_array.data_type = + model->GetArray(concat_inputs->outputs[0]).data_type; lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = concat_temp_array_name; const string& activ_temp_array_name = AvailableArrayName(*model, base_name + "activ_temp"); - model->GetOrCreateArray(activ_temp_array_name); + auto& activ_temp_array = model->GetOrCreateArray(activ_temp_array_name); + activ_temp_array.data_type = + model->GetArray(fully_connected->outputs[0]).data_type; lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] = activ_temp_array_name; AddMessageF("Created temp outputs %s and %s on operator %s", concat_temp_array_name, activ_temp_array_name, diff --git a/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc index 7aec6728da6..d546ff4ec0d 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -266,6 +266,14 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op, model->GetArray(op->outputs[1]).data_type = unique_op->idx_out_type; break; } + case OperatorType::kLstmCell: { + // It's tricky to propagate data types through a LstmCell, as that has + // multiple inputs and outputs, and there are quantized cases with + // mixed (8bit vs 16bit) cases. Fortunately, that should never be needed, + // as the data formats, such as TFLITE, that have LstmCell nodes, also + // have data type fields for all their arrays. + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); diff --git a/tensorflow/lite/toco/toco_tooling.cc b/tensorflow/lite/toco/toco_tooling.cc index 55a454e66de..61dca56d23e 100644 --- a/tensorflow/lite/toco/toco_tooling.cc +++ b/tensorflow/lite/toco/toco_tooling.cc @@ -178,6 +178,23 @@ void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) { // Ignore non-real data types. continue; } + // The enum value QUANTIZED_UINT8 for --inference_type and + // --inference_input_type has long meant just 'QUANTIZED', being used as + // well in mixed 8-bit / 16-bit quantized models. However, + // ConvertIODataTypeToArrayDataType still interpretes it as meaning 8bit, + // and people have run into issues in the situation where they have an + // already mixed 8-bit / 16-bit quantized model in TFLITE format and + // want to run it again through toco, without having to re-specify all the + // extra array info that was used in the (complicated) process of initially + // quantizing that model. In order to have --inference_type=QUANTIZED_UINT8 + // just work in that case, we implement the logic that when an array is + // already quantized, if --inference_type is quantized (so we're not + // asking to dequantize here), no change of quantized data type is to be + // recorded. + if (array->data_type != toco::ArrayDataType::kFloat && + type != toco::ArrayDataType::kFloat) { + continue; + } array->final_data_type = type; }