diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
index 7e7dd02e709..76759a5e851 100644
--- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
@@ -169,6 +169,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
   specs.prune_unused_nodes = true;
   specs.convert_legacy_fed_inputs = true;
   specs.graph_as_function = false;
+  specs.upgrade_legacy = true;
   WarningUnusedFlags(model_flags, toco_flags);
 
   TF_ASSIGN_OR_RETURN(
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index fe14b34eb50..ab909ea9538 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -536,12 +536,21 @@ Status ImporterBase::AddNodesToShapeRefiner() {
             "Input arrays can only have op with single output. Node op:",
             node_name));
       }
-      // For single output nodes, replace them with Placeholder node
+      // For single output nodes, replace them with Placeholder node.
+      DataType dtype = it->second.imported_dtype;
+      // Uses the existing output type if it isn't specified by the user.
+      if (dtype == DT_INVALID) {
+        dtype = node->output_type(0);
+      }
       TF_ASSIGN_OR_RETURN(
-          node, ReplaceWithPlaceholderNode(it->second.shape,
-                                           it->second.imported_dtype, node));
+          node, ReplaceWithPlaceholderNode(it->second.shape, dtype, node));
     } else {
       node->AddAttr("shape", it->second.shape);
+      DataType dtype = it->second.imported_dtype;
+      // Uses the existing output type if it isn't specified by the user.
+      if (dtype == DT_INVALID) {
+        dtype = node->output_type(0);
+      }
       node->AddAttr("dtype", it->second.imported_dtype);
     }
   }
@@ -1299,7 +1308,6 @@ Status ImporterBase::ConvertNode(const Node& node) {
 
   const auto& node_def = node.def();
   mlir::OperationState result(GetLocation(node_def), op_name);
-
   for (int i = 0; i < node.num_outputs(); ++i) {
     // The backedge has been removed, so we shouldn't count the corresponding
     // output from the src node when converting to an operation.
@@ -1688,39 +1696,38 @@ StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
     }
   }
 
+  // Starts to construct the function type.
+  mlir::Builder builder(context);
+  llvm::SmallVector<mlir::Type, 4> arg_types;
+  arg_types.reserve(specs.inputs.size());
   int i = 0;
   for (auto it : specs.inputs) {
     if (arg_nodes->at(i++).node == nullptr) {
       return errors::InvalidArgument("Input ", it.first,
                                      " was not found in graph");
     }
+    mlir::Type element_type;
+    const auto& node_info = it.second;
+    DataType imported_dtype = node_info.imported_dtype;
+    // Uses the existing output type if it isn't specified by the user.
+    if (imported_dtype == DT_INVALID) {
+      imported_dtype = arg_nodes->back().node->output_type(0);
+    }
+    TF_RETURN_IF_ERROR(
+        ::tensorflow::ConvertDataType(imported_dtype, builder, &element_type));
+    llvm::SmallVector<int64_t, 4> shape;
+    TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape));
+    arg_types.push_back(mlir::RankedTensorType::get(shape, element_type));
   }
+
+  llvm::SmallVector<mlir::Type, 4> ret_types;
+  ret_types.reserve(specs.output_arrays.size());
   for (int i = 0, e = specs.output_arrays_order.size(); i != e; ++i) {
     if (ret_nodes->at(i).node == nullptr) {
       return errors::InvalidArgument("Output ", specs.output_arrays_order[i],
                                      " was not found in graph");
     }
   }
-
-  // Starts to construct the function type.
-  llvm::SmallVector<mlir::Type, 4> arg_types;
-  llvm::SmallVector<mlir::Type, 4> ret_types;
-  arg_types.reserve(specs.inputs.size());
-  ret_types.reserve(specs.output_arrays.size());
-  mlir::Builder builder(context);
-
-  // Input nodes as function arguments.
-  for (const auto& input : specs.inputs) {
-    mlir::Type element_type;
-    const auto& node_info = input.second;
-    TF_RETURN_IF_ERROR(::tensorflow::ConvertDataType(node_info.imported_dtype,
-                                                     builder, &element_type));
-    llvm::SmallVector<int64_t, 4> shape;
-    TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape));
-    arg_types.push_back(mlir::RankedTensorType::get(shape, element_type));
-  }
-
-  // Output nodes as function returns.
   for (const auto& ret : *ret_nodes) {
     if (ret.node->num_outputs() <= ret.index) {
       return errors::InvalidArgument("Invalid output index ", ret.index,