Add an option to export entry function to function library
Currently, main function is added to the given Graph and other functions are added to the function library. That's not desirable if we want to export MLIR module that was imported from a function. This change introduces a way to invert ConvertFunctionToMlir function. This is required to fall back to the old bridge for second phase of the compilation. PiperOrigin-RevId: 360398279 Change-Id: I7eaef32349c4b69c8eb3d0bfb929916794e34fc3
This commit is contained in:
parent
62dfa9e1bf
commit
e305de9662
@ -0,0 +1,20 @@
|
||||
// RUN: tf-mlir-translate -mlir-to-graphdef -tf-export-entry-func-to-flib %s -o - 2>&1 | FileCheck %s
|
||||
|
||||
module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 458 : i32}} {
|
||||
func @main() {
|
||||
tf_executor.graph {
|
||||
%0:2 = tf_executor.island wraps "tf.Const"() {device = "TPU:0", name = "const", dtype = "tfdtype$DT_INT32", value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-NOT: node
|
||||
|
||||
// CHECK: library
|
||||
// CHECK-NEXT: function
|
||||
// CHECK-NEXT: signature
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK: node_def
|
||||
// CHECK: op: "Const"
|
@ -667,7 +667,8 @@ Status Exporter::Convert(mlir::ModuleOp module,
|
||||
if (function.isExternal())
|
||||
return errors::FailedPrecondition("External functions not supported");
|
||||
|
||||
if (function.getName() == entry_func_id) {
|
||||
if (function.getName() == entry_func_id &&
|
||||
!configs.export_entry_func_to_flib) {
|
||||
entry_func.emplace(function);
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -675,13 +676,17 @@ Status Exporter::Convert(mlir::ModuleOp module,
|
||||
}
|
||||
}
|
||||
|
||||
if (!entry_func.has_value())
|
||||
return errors::FailedPrecondition("entry function `main` must be present");
|
||||
if (!configs.export_entry_func_to_flib) {
|
||||
if (!entry_func.has_value())
|
||||
return errors::FailedPrecondition(
|
||||
"entry function `main` must be present");
|
||||
|
||||
// Updates the graph and the function library definition.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*graph, Exporter::Convert(configs, tf_dialect, entry_func.value(),
|
||||
&flib, control_ret_nodes));
|
||||
}
|
||||
|
||||
// Updates the graph and the function library definition.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*graph, Exporter::Convert(configs, tf_dialect, entry_func.value(), &flib,
|
||||
control_ret_nodes));
|
||||
for (auto& func_def : flib.function()) {
|
||||
TF_RETURN_IF_ERROR(flib_def->AddFunctionDef(func_def));
|
||||
}
|
||||
@ -716,8 +721,16 @@ StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
|
||||
mlir::ModuleOp module, const GraphExportConfig& configs) {
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(),
|
||||
FunctionDefLibrary());
|
||||
auto graph = absl::make_unique<Graph>(flib_def);
|
||||
std::unique_ptr<Graph> graph;
|
||||
TF_RETURN_IF_ERROR(ConvertMlirToGraph(module, configs, &graph, &flib_def));
|
||||
|
||||
// If the entry function is exported to flib, then no graph is constructed.
|
||||
// Construct one in that case.
|
||||
if (configs.export_entry_func_to_flib) {
|
||||
graph = std::make_unique<Graph>(OpRegistry::Global());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(flib_def.ToProto()));
|
||||
|
||||
auto graphdef = absl::make_unique<GraphDef>();
|
||||
graph->ToGraphDef(graphdef.get());
|
||||
if (!configs.export_library) graphdef->clear_library();
|
||||
|
@ -75,6 +75,9 @@ struct GraphExportConfig {
|
||||
bool export_library = true;
|
||||
// Whether to export debug original node name in the GraphDef.
|
||||
bool export_debug_info = true;
|
||||
// Whether to export the entry function to function library instead of the
|
||||
// graph.
|
||||
bool export_entry_func_to_flib = false;
|
||||
};
|
||||
|
||||
// Parses the command line flag strings to the specification of nodes in
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
|
||||
using llvm::cl::opt;
|
||||
|
||||
// Import options.
|
||||
// NOLINTNEXTLINE
|
||||
opt<std::string> input_arrays(
|
||||
"tf-input-arrays", llvm::cl::desc("Input tensor names, separated by ','"),
|
||||
@ -115,3 +116,11 @@ opt<bool> enable_shape_inference(
|
||||
"tf-enable-shape-inference-on-import",
|
||||
llvm::cl::desc("Enable shape inference on import (temporary)"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
// Export options.
|
||||
// NOLINTNEXTLINE
|
||||
opt<bool> export_entry_func_to_flib(
|
||||
"tf-export-entry-func-to-flib",
|
||||
llvm::cl::desc(
|
||||
"Export entry function to function library instead of graph"),
|
||||
llvm::cl::init(false));
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
|
||||
// Please see the implementation file for documentation of these options.
|
||||
|
||||
// Import options.
|
||||
extern llvm::cl::opt<std::string> input_arrays;
|
||||
extern llvm::cl::opt<std::string> input_dtypes;
|
||||
extern llvm::cl::opt<std::string> input_shapes;
|
||||
@ -42,4 +43,7 @@ extern llvm::cl::opt<bool> upgrade_legacy;
|
||||
// TODO(jpienaar): Temporary flag, flip default and remove.
|
||||
extern llvm::cl::opt<bool> enable_shape_inference;
|
||||
|
||||
// Export options.
|
||||
extern llvm::cl::opt<bool> export_entry_func_to_flib;
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_CL_H_
|
||||
|
@ -75,6 +75,7 @@ static LogicalResult MlirToGraphdefTranslateFunction(
|
||||
|
||||
// TODO(fengliuai): Add exporter flags.
|
||||
tensorflow::GraphExportConfig confs;
|
||||
confs.export_entry_func_to_flib = export_entry_func_to_flib;
|
||||
StatusOr<std::unique_ptr<tensorflow::GraphDef>> graphdef_or(
|
||||
tensorflow::ConvertMlirToGraphdef(module, confs));
|
||||
if (!graphdef_or.status().ok()) {
|
||||
|
Loading…
Reference in New Issue
Block a user