Apply functionalization, and use the existing data type if one is not specified

This patch contains two changes:

- Apply the functionalization pass before the graph is imported into MLIR.
- If an input data type isn't specified, use the data type already recorded in the graph (illustrated in the sketch after this list).
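For a concrete picture of the second change, here is a minimal standalone sketch of the fallback rule the patch applies at each input array. The enum values below are illustrative stand-ins, not TensorFlow's real DataType definitions, and ResolveInputDtype is a made-up helper name:

#include <iostream>

// Stand-in for TensorFlow's DataType enum; DT_INVALID plays the same
// role as in the patch: "the user did not specify a type".
enum DataType { DT_INVALID = 0, DT_FLOAT = 1, DT_INT32 = 3 };

// The rule the patch applies in each place it touches: prefer the
// user-specified dtype, and fall back to the type already recorded on
// the graph node when the user left it unspecified.
DataType ResolveInputDtype(DataType user_specified, DataType graph_type) {
  return user_specified == DT_INVALID ? graph_type : user_specified;
}

int main() {
  // Unspecified input dtype: the graph's recorded DT_FLOAT is used.
  std::cout << ResolveInputDtype(DT_INVALID, DT_FLOAT) << "\n";  // prints 1
  // An explicit request wins over the recorded type.
  std::cout << ResolveInputDtype(DT_INT32, DT_FLOAT) << "\n";    // prints 3
  return 0;
}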

PiperOrigin-RevId: 276629651
Change-Id: I9251ceac66e245836dba8fcac8ed4a576c4f18ae
Authored by Feng Liu on 2019-10-24 22:08:06 -07:00; committed by TensorFlower Gardener
parent 44848dc0b3
commit 087b39e079
2 changed files with 32 additions and 24 deletions


@@ -169,6 +169,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
   specs.prune_unused_nodes = true;
   specs.convert_legacy_fed_inputs = true;
   specs.graph_as_function = false;
+  specs.upgrade_legacy = true;
   WarningUnusedFlags(model_flags, toco_flags);
 
   TF_ASSIGN_OR_RETURN(
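The one-line addition above (specs.upgrade_legacy = true;) is what turns on the functionalization pass: before MLIR import, legacy dataflow-style control flow built from Switch/Merge nodes is rewritten into functional If/While ops. The toy program below models the two shapes of the same conditional on plain ints; it is a sketch of the idea with every name invented for illustration, not TensorFlow's implementation (a real legacy graph executes only the live branch):

#include <functional>
#include <iostream>

// Legacy-style conditional: Switch tags a value with the predicate and
// Merge selects one of the two branch results. (Unlike a real TF graph,
// this toy model evaluates both branches eagerly.)
struct SwitchOut {
  int value;
  bool take_then;
};

SwitchOut Switch(int x, bool pred) { return {x, pred}; }

int Merge(const SwitchOut& s, int then_result, int else_result) {
  return s.take_then ? then_result : else_result;
}

// Functionalized conditional: a single If op carrying its branches as
// explicit functions, which is the form the MLIR importer expects.
int If(bool pred, int x, const std::function<int(int)>& then_fn,
       const std::function<int(int)>& else_fn) {
  return pred ? then_fn(x) : else_fn(x);
}

int main() {
  const bool pred = true;
  const int x = 5;

  // Legacy graph shape: Switch + per-branch computation + Merge.
  const SwitchOut s = Switch(x, pred);
  const int legacy =
      Merge(s, /*then_result=*/s.value * 2, /*else_result=*/s.value + 1);

  // The same computation after functionalization: one If op.
  const int functional =
      If(pred, x, [](int v) { return v * 2; }, [](int v) { return v + 1; });

  std::cout << legacy << " == " << functional << "\n";  // prints "10 == 10"
  return 0;
}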


@@ -536,12 +536,21 @@ Status ImporterBase::AddNodesToShapeRefiner() {
               "Input arrays can only have op with single output. Node op:",
               node_name));
         }
-        // For single output nodes, replace them with Placeholder node
+        // For single output nodes, replace them with Placeholder node.
+        DataType dtype = it->second.imported_dtype;
+        // Uses the existing output type if it isn't specified by the user.
+        if (dtype == DT_INVALID) {
+          dtype = node->output_type(0);
+        }
         TF_ASSIGN_OR_RETURN(
-            node, ReplaceWithPlaceholderNode(it->second.shape,
-                                             it->second.imported_dtype, node));
+            node, ReplaceWithPlaceholderNode(it->second.shape, dtype, node));
       } else {
         node->AddAttr("shape", it->second.shape);
-        node->AddAttr("dtype", it->second.imported_dtype);
+        DataType dtype = it->second.imported_dtype;
+        // Uses the existing output type if it isn't specified by the user.
+        if (dtype == DT_INVALID) {
+          dtype = node->output_type(0);
+        }
+        node->AddAttr("dtype", dtype);
       }
     }
@@ -1299,7 +1308,6 @@ Status ImporterBase::ConvertNode(const Node& node) {
   const auto& node_def = node.def();
   mlir::OperationState result(GetLocation(node_def), op_name);
-
   for (int i = 0; i < node.num_outputs(); ++i) {
     // The backedge has been removed, so we shouldn't count the corresponding
     // output from the src node when converting to an operation.
@@ -1688,39 +1696,38 @@ StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
     }
   }
 
-  // Starts to construct the function type.
-  llvm::SmallVector<mlir::Type, 4> arg_types;
-  llvm::SmallVector<mlir::Type, 4> ret_types;
-  arg_types.reserve(specs.inputs.size());
-  ret_types.reserve(specs.output_arrays.size());
-  mlir::Builder builder(context);
-
-  // Input nodes as function arguments.
-  for (const auto& input : specs.inputs) {
-    mlir::Type element_type;
-    const auto& node_info = input.second;
-    TF_RETURN_IF_ERROR(::tensorflow::ConvertDataType(node_info.imported_dtype,
-                                                     builder, &element_type));
-    llvm::SmallVector<int64_t, 4> shape;
-    TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape));
-    arg_types.push_back(mlir::RankedTensorType::get(shape, element_type));
-  }
+  // Starts to construct the function type.
+  mlir::Builder builder(context);
+  llvm::SmallVector<mlir::Type, 4> arg_types;
+  arg_types.reserve(specs.inputs.size());
+  int i = 0;
+  for (auto it : specs.inputs) {
+    Node* arg_node = arg_nodes->at(i++).node;
+    if (arg_node == nullptr) {
+      return errors::InvalidArgument("Input ", it.first,
+                                     " was not found in graph");
+    }
+    mlir::Type element_type;
+    const auto& node_info = it.second;
+    DataType imported_dtype = node_info.imported_dtype;
+    // Uses the existing output type if it isn't specified by the user.
+    if (imported_dtype == DT_INVALID) {
+      imported_dtype = arg_node->output_type(0);
+    }
+    TF_RETURN_IF_ERROR(
+        ::tensorflow::ConvertDataType(imported_dtype, builder, &element_type));
+    llvm::SmallVector<int64_t, 4> shape;
+    TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape));
+    arg_types.push_back(mlir::RankedTensorType::get(shape, element_type));
+  }
+
+  llvm::SmallVector<mlir::Type, 4> ret_types;
+  ret_types.reserve(specs.output_arrays.size());
+  for (int i = 0, e = specs.output_arrays_order.size(); i != e; ++i) {
+    if (ret_nodes->at(i).node == nullptr) {
+      return errors::InvalidArgument("Output ", specs.output_arrays_order[i],
+                                     " was not found in graph");
+    }
+  }
 
   // Output nodes as function returns.
   for (const auto& ret : *ret_nodes) {
     if (ret.node->num_outputs() <= ret.index) {
       return errors::InvalidArgument("Invalid output index ", ret.index,