Fix "Converting DataType 'INVALID' to MLIR Type" bug
PiperOrigin-RevId: 276387998 Change-Id: Ide7dd335e1d1c1463e318ae39de3ac84a9aeeddf
This commit is contained in:
parent
79230ddda7
commit
d752708859
@ -106,10 +106,30 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
std::vector<double> node_maxs;
|
||||
tensorflow::DataType inference_type =
|
||||
ConvertIODataTypeToDataType(toco_flags.inference_type());
|
||||
|
||||
// Build a map from placeholder to data types.
|
||||
llvm::StringMap<DataType> placeholder_data_type_map;
|
||||
for (const NodeDef& node_def : input.node()) {
|
||||
if (node_def.op() == "Placeholder" && node_def.attr().count("dtype") > 0) {
|
||||
placeholder_data_type_map[node_def.name()] =
|
||||
node_def.attr().at("dtype").type();
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& flag : model_flags.input_arrays()) {
|
||||
// TOCO doesn't required `data_type` to be filled for every input.
|
||||
// If it's not filled, try to get the data type from the placeholder.
|
||||
auto toco_data_type = flag.data_type();
|
||||
DataType data_type;
|
||||
if (toco_data_type == ::toco::IODataType::IO_DATA_TYPE_UNKNOWN &&
|
||||
placeholder_data_type_map.find(flag.name()) !=
|
||||
placeholder_data_type_map.end()) {
|
||||
data_type = placeholder_data_type_map[flag.name()];
|
||||
} else {
|
||||
data_type = ConvertIODataTypeToDataType(toco_data_type);
|
||||
}
|
||||
node_names.push_back(flag.name());
|
||||
node_dtypes.push_back(
|
||||
DataType_Name(ConvertIODataTypeToDataType(flag.data_type())));
|
||||
node_dtypes.push_back(DataType_Name(data_type));
|
||||
node_shapes.push_back(std::vector<int>(flag.shape().dims().begin(),
|
||||
flag.shape().dims().end()));
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user