Fix "Converting DataType 'INVALID' to MLIR Type" bug

PiperOrigin-RevId: 276387998
Change-Id: Ide7dd335e1d1c1463e318ae39de3ac84a9aeeddf
This commit is contained in:
Yu-Cheng Ling 2019-10-23 17:44:13 -07:00 committed by TensorFlower Gardener
parent 79230ddda7
commit d752708859

View File

@ -106,10 +106,30 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
std::vector<double> node_maxs;
tensorflow::DataType inference_type =
ConvertIODataTypeToDataType(toco_flags.inference_type());
// Build a map from placeholder to data types.
llvm::StringMap<DataType> placeholder_data_type_map;
for (const NodeDef& node_def : input.node()) {
if (node_def.op() == "Placeholder" && node_def.attr().count("dtype") > 0) {
placeholder_data_type_map[node_def.name()] =
node_def.attr().at("dtype").type();
}
}
for (auto& flag : model_flags.input_arrays()) {
// TOCO doesn't required `data_type` to be filled for every input.
// If it's not filled, try to get the data type from the placeholder.
auto toco_data_type = flag.data_type();
DataType data_type;
if (toco_data_type == ::toco::IODataType::IO_DATA_TYPE_UNKNOWN &&
placeholder_data_type_map.find(flag.name()) !=
placeholder_data_type_map.end()) {
data_type = placeholder_data_type_map[flag.name()];
} else {
data_type = ConvertIODataTypeToDataType(toco_data_type);
}
node_names.push_back(flag.name());
node_dtypes.push_back(
DataType_Name(ConvertIODataTypeToDataType(flag.data_type())));
node_dtypes.push_back(DataType_Name(data_type));
node_shapes.push_back(std::vector<int>(flag.shape().dims().begin(),
flag.shape().dims().end()));