Apply functionalization and also use existing data type if it is not specified
This patch contains two changes: - apply functionalization pass before the graph is imported to MLIR. - If the input data type isn't specified, we use the one in the graph. PiperOrigin-RevId: 276629651 Change-Id: I9251ceac66e245836dba8fcac8ed4a576c4f18ae
This commit is contained in:
parent
44848dc0b3
commit
087b39e079
@ -169,6 +169,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
specs.prune_unused_nodes = true;
|
||||
specs.convert_legacy_fed_inputs = true;
|
||||
specs.graph_as_function = false;
|
||||
specs.upgrade_legacy = true;
|
||||
WarningUnusedFlags(model_flags, toco_flags);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
|
@ -536,12 +536,21 @@ Status ImporterBase::AddNodesToShapeRefiner() {
|
||||
"Input arrays can only have op with single output. Node op:",
|
||||
node_name));
|
||||
}
|
||||
// For single output nodes, replace them with Placeholder node
|
||||
// For single output nodes, replace them with Placeholder node.
|
||||
DataType dtype = it->second.imported_dtype;
|
||||
// Uses the existing output type if it isn't specified by the user.
|
||||
if (dtype == DT_INVALID) {
|
||||
dtype = node->output_type(0);
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
node, ReplaceWithPlaceholderNode(it->second.shape,
|
||||
it->second.imported_dtype, node));
|
||||
node, ReplaceWithPlaceholderNode(it->second.shape, dtype, node));
|
||||
} else {
|
||||
node->AddAttr("shape", it->second.shape);
|
||||
DataType dtype = it->second.imported_dtype;
|
||||
// Uses the existing output type if it isn't specified by the user.
|
||||
if (dtype == DT_INVALID) {
|
||||
dtype = node->output_type(0);
|
||||
}
|
||||
node->AddAttr("dtype", it->second.imported_dtype);
|
||||
}
|
||||
}
|
||||
@ -1299,7 +1308,6 @@ Status ImporterBase::ConvertNode(const Node& node) {
|
||||
|
||||
const auto& node_def = node.def();
|
||||
mlir::OperationState result(GetLocation(node_def), op_name);
|
||||
|
||||
for (int i = 0; i < node.num_outputs(); ++i) {
|
||||
// The backedge has been removed, so we shouldn't count the corresponding
|
||||
// output from the src node when converting to an operation.
|
||||
@ -1688,39 +1696,38 @@ StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
|
||||
}
|
||||
}
|
||||
|
||||
// Starts to construct the function type.
|
||||
mlir::Builder builder(context);
|
||||
llvm::SmallVector<mlir::Type, 4> arg_types;
|
||||
arg_types.reserve(specs.inputs.size());
|
||||
int i = 0;
|
||||
for (auto it : specs.inputs) {
|
||||
if (arg_nodes->at(i++).node == nullptr) {
|
||||
return errors::InvalidArgument("Input ", it.first,
|
||||
" was not found in graph");
|
||||
}
|
||||
mlir::Type element_type;
|
||||
const auto& node_info = it.second;
|
||||
DataType imported_dtype = node_info.imported_dtype;
|
||||
// Uses the existing output type if it isn't specified by the user.
|
||||
if (imported_dtype == DT_INVALID) {
|
||||
imported_dtype = arg_nodes->back().node->output_type(0);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
::tensorflow::ConvertDataType(imported_dtype, builder, &element_type));
|
||||
llvm::SmallVector<int64_t, 4> shape;
|
||||
TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape));
|
||||
arg_types.push_back(mlir::RankedTensorType::get(shape, element_type));
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::Type, 4> ret_types;
|
||||
ret_types.reserve(specs.output_arrays.size());
|
||||
for (int i = 0, e = specs.output_arrays_order.size(); i != e; ++i) {
|
||||
if (ret_nodes->at(i).node == nullptr) {
|
||||
return errors::InvalidArgument("Output ", specs.output_arrays_order[i],
|
||||
" was not found in graph");
|
||||
}
|
||||
}
|
||||
|
||||
// Starts to construct the function type.
|
||||
llvm::SmallVector<mlir::Type, 4> arg_types;
|
||||
llvm::SmallVector<mlir::Type, 4> ret_types;
|
||||
arg_types.reserve(specs.inputs.size());
|
||||
ret_types.reserve(specs.output_arrays.size());
|
||||
mlir::Builder builder(context);
|
||||
|
||||
// Input nodes as function arguments.
|
||||
for (const auto& input : specs.inputs) {
|
||||
mlir::Type element_type;
|
||||
const auto& node_info = input.second;
|
||||
TF_RETURN_IF_ERROR(::tensorflow::ConvertDataType(node_info.imported_dtype,
|
||||
builder, &element_type));
|
||||
llvm::SmallVector<int64_t, 4> shape;
|
||||
TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape));
|
||||
arg_types.push_back(mlir::RankedTensorType::get(shape, element_type));
|
||||
}
|
||||
|
||||
// Output nodes as function returns.
|
||||
for (const auto& ret : *ret_nodes) {
|
||||
if (ret.node->num_outputs() <= ret.index) {
|
||||
return errors::InvalidArgument("Invalid output index ", ret.index,
|
||||
|
Loading…
x
Reference in New Issue
Block a user