Expose control ret nodes when converting a MLIR module to TensorFlow Graph.
This adds an option to capture control ret nodes which then can be used later and not be accidentally pruned away. Users of control ret nodes may need to remap nodes as control ret node names are not preserved. This is necessary as Graph does not store information defining which nodes are control ret nodes. PiperOrigin-RevId: 292199204 Change-Id: I4074cc6b2d514ff89435b9f419a71f7417526ab0
This commit is contained in:
parent
7f47fe01e4
commit
5177a9ff11
@ -184,7 +184,8 @@ class Exporter {
|
||||
// converted to the library functions in that graph.
|
||||
static Status Convert(mlir::ModuleOp module, const GraphExportConfig& configs,
|
||||
std::unique_ptr<Graph>* graph,
|
||||
FunctionLibraryDefinition* flib_def);
|
||||
FunctionLibraryDefinition* flib_def,
|
||||
absl::flat_hash_set<Node*>* control_ret_nodes);
|
||||
|
||||
// Converts a given FuncOp to a FunctionDef and adds it to the function
|
||||
// definition library
|
||||
@ -790,7 +791,8 @@ Status Exporter::ConvertLibFunction(const GraphExportConfig& configs,
|
||||
Status Exporter::Convert(mlir::ModuleOp module,
|
||||
const GraphExportConfig& configs,
|
||||
std::unique_ptr<Graph>* graph,
|
||||
FunctionLibraryDefinition* flib_def) {
|
||||
FunctionLibraryDefinition* flib_def,
|
||||
absl::flat_hash_set<Node*>* control_ret_nodes) {
|
||||
mlir::Identifier entry_func_id =
|
||||
mlir::Identifier::get("main", module.getContext());
|
||||
absl::optional<mlir::FuncOp> entry_func;
|
||||
@ -812,10 +814,9 @@ Status Exporter::Convert(mlir::ModuleOp module,
|
||||
return errors::FailedPrecondition("entry function `main` must be present");
|
||||
|
||||
// Updates the graph and the function library definition.
|
||||
absl::flat_hash_set<Node*> control_ret_nodes;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*graph, Exporter::Convert(configs, tf_dialect, entry_func.value(), &flib,
|
||||
&control_ret_nodes));
|
||||
control_ret_nodes));
|
||||
for (auto& func_def : flib.function()) {
|
||||
TF_RETURN_IF_ERROR(flib_def->AddFunctionDef(func_def));
|
||||
}
|
||||
@ -829,9 +830,19 @@ Status Exporter::Convert(mlir::ModuleOp module,
|
||||
Status ConvertMlirToGraph(mlir::ModuleOp module,
|
||||
const GraphExportConfig& configs,
|
||||
std::unique_ptr<Graph>* graph,
|
||||
FunctionLibraryDefinition* flib_def) {
|
||||
FunctionLibraryDefinition* flib_def,
|
||||
absl::flat_hash_set<Node*>* control_ret_nodes) {
|
||||
TF_RETURN_IF_ERROR(HasSingleGraphSingleOpIslandsFunctions(module));
|
||||
return Exporter::Convert(module, configs, graph, flib_def);
|
||||
return Exporter::Convert(module, configs, graph, flib_def, control_ret_nodes);
|
||||
}
|
||||
|
||||
Status ConvertMlirToGraph(mlir::ModuleOp module,
|
||||
const GraphExportConfig& configs,
|
||||
std::unique_ptr<Graph>* graph,
|
||||
FunctionLibraryDefinition* flib_def) {
|
||||
absl::flat_hash_set<Node*> control_ret_nodes;
|
||||
return ConvertMlirToGraph(module, configs, graph, flib_def,
|
||||
&control_ret_nodes);
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_EXPORT_GRAPHDEF_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_EXPORT_GRAPHDEF_H_
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
@ -34,6 +35,15 @@ using stream_executor::port::StatusOr;
|
||||
StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
|
||||
mlir::ModuleOp module, const GraphExportConfig& configs);
|
||||
|
||||
// Converts an MLIR module to TensorFlow graph and FunctionLibraryDefinition.
|
||||
// The "main" function of the module is stored in the graph and the rest of
|
||||
// functions are stored in the library. Control ret nodes are stored separately
|
||||
// in `control_ret_nodes`.
|
||||
stream_executor::port::Status ConvertMlirToGraph(
|
||||
mlir::ModuleOp module, const GraphExportConfig& configs,
|
||||
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
|
||||
absl::flat_hash_set<Node*>* control_ret_nodes);
|
||||
|
||||
// Converts an MLIR module to TensorFlow graph and FunctionLibraryDefinition.
|
||||
// The "main" function of the module is stored in the graph and the rest of
|
||||
// functions are stored in the library.
|
||||
|
Loading…
x
Reference in New Issue
Block a user