diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index 96962b470d7..7e7dd02e709 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -106,10 +106,30 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, std::vector node_maxs; tensorflow::DataType inference_type = ConvertIODataTypeToDataType(toco_flags.inference_type()); + + // Build a map from placeholder to data types. + llvm::StringMap placeholder_data_type_map; + for (const NodeDef& node_def : input.node()) { + if (node_def.op() == "Placeholder" && node_def.attr().count("dtype") > 0) { + placeholder_data_type_map[node_def.name()] = + node_def.attr().at("dtype").type(); + } + } + for (auto& flag : model_flags.input_arrays()) { + // TOCO doesn't required `data_type` to be filled for every input. + // If it's not filled, try to get the data type from the placeholder. + auto toco_data_type = flag.data_type(); + DataType data_type; + if (toco_data_type == ::toco::IODataType::IO_DATA_TYPE_UNKNOWN && + placeholder_data_type_map.find(flag.name()) != + placeholder_data_type_map.end()) { + data_type = placeholder_data_type_map[flag.name()]; + } else { + data_type = ConvertIODataTypeToDataType(toco_data_type); + } node_names.push_back(flag.name()); - node_dtypes.push_back( - DataType_Name(ConvertIODataTypeToDataType(flag.data_type()))); + node_dtypes.push_back(DataType_Name(data_type)); node_shapes.push_back(std::vector(flag.shape().dims().begin(), flag.shape().dims().end()));