Fix "Converting DataType 'INVALID' to MLIR Type" bug
PiperOrigin-RevId: 276387998 Change-Id: Ide7dd335e1d1c1463e318ae39de3ac84a9aeeddf
This commit is contained in:
parent
79230ddda7
commit
d752708859
@ -106,10 +106,30 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
|||||||
std::vector<double> node_maxs;
|
std::vector<double> node_maxs;
|
||||||
tensorflow::DataType inference_type =
|
tensorflow::DataType inference_type =
|
||||||
ConvertIODataTypeToDataType(toco_flags.inference_type());
|
ConvertIODataTypeToDataType(toco_flags.inference_type());
|
||||||
|
|
||||||
|
// Build a map from placeholder to data types.
|
||||||
|
llvm::StringMap<DataType> placeholder_data_type_map;
|
||||||
|
for (const NodeDef& node_def : input.node()) {
|
||||||
|
if (node_def.op() == "Placeholder" && node_def.attr().count("dtype") > 0) {
|
||||||
|
placeholder_data_type_map[node_def.name()] =
|
||||||
|
node_def.attr().at("dtype").type();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (auto& flag : model_flags.input_arrays()) {
|
for (auto& flag : model_flags.input_arrays()) {
|
||||||
|
// TOCO doesn't required `data_type` to be filled for every input.
|
||||||
|
// If it's not filled, try to get the data type from the placeholder.
|
||||||
|
auto toco_data_type = flag.data_type();
|
||||||
|
DataType data_type;
|
||||||
|
if (toco_data_type == ::toco::IODataType::IO_DATA_TYPE_UNKNOWN &&
|
||||||
|
placeholder_data_type_map.find(flag.name()) !=
|
||||||
|
placeholder_data_type_map.end()) {
|
||||||
|
data_type = placeholder_data_type_map[flag.name()];
|
||||||
|
} else {
|
||||||
|
data_type = ConvertIODataTypeToDataType(toco_data_type);
|
||||||
|
}
|
||||||
node_names.push_back(flag.name());
|
node_names.push_back(flag.name());
|
||||||
node_dtypes.push_back(
|
node_dtypes.push_back(DataType_Name(data_type));
|
||||||
DataType_Name(ConvertIODataTypeToDataType(flag.data_type())));
|
|
||||||
node_shapes.push_back(std::vector<int>(flag.shape().dims().begin(),
|
node_shapes.push_back(std::vector<int>(flag.shape().dims().begin(),
|
||||||
flag.shape().dims().end()));
|
flag.shape().dims().end()));
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user