Do not add FunctionDefs to graph library each time a function is exported.

Exporter adds the existing FunctionDefs to the graph library each time an MLIR function is exported. This has a huge cost for exporting graphs with hundreds of functions. Instead, we can add them to the library only when we see a legacy call op.

PiperOrigin-RevId: 360989162
Change-Id: Ib834e8fd0124d6696ae25de1c7b22972066b20f8
This commit is contained in:
Prakalp Srivastava 2021-03-04 13:18:13 -08:00 committed by TensorFlower Gardener
parent 3d2714f0e7
commit 81adfef6d3

View File

@ -436,9 +436,6 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
graph->set_versions(versions);
}
// We have to add the function library here, so a custom operation, which is
// defined in the function library can be added to the graph.
TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
Exporter exporter(graph.get(), tf_dialect);
auto graph_op = llvm::cast<mlir::tf_executor::GraphOp>(block.front());
@ -509,6 +506,8 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
if (func != nullptr) {
TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, func, flib,
visited_functions));
// TODO(prakalps): Optimize to only add the requested function to graph
// library rather than the all the functions exported so far.
TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
}
return Status::OK();
@ -691,6 +690,11 @@ Status Exporter::Convert(mlir::ModuleOp module,
TF_ASSIGN_OR_RETURN(
*graph, Exporter::Convert(configs, tf_dialect, entry_func.value(),
&flib, visited_functions, control_ret_nodes));
// Add FunctionDefs and GradientDefs of MLIR functions to graph's function
// library. If duplicate FunctionDefs already exist (can happen if exporter
// had already added some FunctionDefs to the library to support legacy
// calls), they are ignored.
TF_RETURN_IF_ERROR(graph->get()->AddFunctionLibrary(flib));
}
for (auto& func_def : flib.function()) {
@ -734,8 +738,10 @@ StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
// Construct one in that case.
if (configs.export_entry_func_to_flib) {
graph = std::make_unique<Graph>(OpRegistry::Global());
// TODO(hinsu): Avoid Proto -> Memory -> Proto conversion here.
FunctionDefLibrary flib = flib_def.ToProto();
TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(flib));
}
TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(flib_def.ToProto()));
auto graphdef = absl::make_unique<GraphDef>();
graph->ToGraphDef(graphdef.get());