From 4a78601139b9b331981ccbd8cbb9f2c248d47b3e Mon Sep 17 00:00:00 2001
From: Jacques Pienaar
Date: Wed, 25 Sep 2019 19:13:44 -0700
Subject: [PATCH] Rename import/export configs for Graph to MLIR consistently

NodeSpecs -> GraphImportConfig
ExporterConfigs -> GraphExportConfig

PiperOrigin-RevId: 271257644
---
 .../lite/python/graphdef_to_tfl_flatbuffer.cc |  2 +-
 tensorflow/compiler/mlir/python/mlir.i        |  2 +-
 .../transforms/tf_graph_optimization_pass.cc  |  4 +--
 .../tensorflow/translate/export_graphdef.cc   | 18 +++++++-----
 .../tensorflow/translate/export_graphdef.h    |  4 +--
 .../mlir/tensorflow/translate/import_model.cc | 29 ++++++++++---------
 .../mlir/tensorflow/translate/import_model.h  |  4 +--
 .../translate/mlir_roundtrip_flags.cc         |  4 +--
 .../translate/mlir_roundtrip_flags.h          |  9 +++---
 .../translate/mlir_roundtrip_pass.cc          |  4 +--
 .../tensorflow/translate/tf_mlir_translate.cc |  2 +-
 .../tf_mlir_translate_registration.cc         |  2 +-
 tensorflow/compiler/tf2xla/BUILD              |  1 +
 .../compiler/tf2xla/mlir_bridge_pass.cc       |  5 ++--
 14 files changed, 47 insertions(+), 43 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
index e3364fb6e5e..3b3ed51b782 100644
--- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
@@ -90,7 +90,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
                                          const GraphDef& input,
                                          string* result) {
   mlir::MLIRContext context;
-  NodeSpecs specs;
+  GraphImportConfig specs;

   // Parse input arrays.
   std::vector<string> node_names;
diff --git a/tensorflow/compiler/mlir/python/mlir.i b/tensorflow/compiler/mlir/python/mlir.i
index 8ae893f6808..c1ec20788b6 100644
--- a/tensorflow/compiler/mlir/python/mlir.i
+++ b/tensorflow/compiler/mlir/python/mlir.i
@@ -40,7 +40,7 @@ string ImportGraphDef(const string &proto, const string &pass_pipeline, TF_Statu
     return "// error";
   }
   GraphDebugInfo debug_info;
-  NodeSpecs specs;
+  GraphImportConfig specs;
   mlir::MLIRContext context;
   auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
   if (!module.ok()) {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc
index 830103f1c80..4b74f3e6ca3 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc
@@ -59,7 +59,7 @@ void GraphOptPass::runOnModule() {
   // Convert MLIR to Graph
   FunctionLibraryDefinition flib_def(OpRegistry::Global(),
                                      FunctionDefLibrary());
-  ExporterConfigs confs;
+  GraphExportConfig confs;
   auto graph = absl::make_unique<Graph>(flib_def);
   Status status = ConvertMlirToGraph(module_in, confs, &graph, &flib_def);
   if (!status.ok()) {
@@ -90,7 +90,7 @@ void GraphOptPass::runOnModule() {

   // Convert Graph to MLIR
   GraphDebugInfo debug_info;
-  NodeSpecs specs;
+  GraphImportConfig specs;
   auto module_or_status =
       ConvertGraphToMlir(**options.graph, debug_info, flib_def, specs, &ctx);
   if (!module_or_status.ok()) {
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
index a0e04c05ab8..83e663e4cb9 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
@@ -107,13 +107,13 @@ class Exporter {
   // one entry function, which is identified by name "main". This entry function
   // is converted to the base of the graph graph. The rest of the functions are
   // converted to the library functions in that graph.
-  static Status Convert(mlir::ModuleOp module, const ExporterConfigs& configs,
+  static Status Convert(mlir::ModuleOp module, const GraphExportConfig& configs,
                         std::unique_ptr<Graph>* graph,
                         FunctionLibraryDefinition* flib_def);

   // Converts a given FuncOp to a FunctionDef and adds it to the function
   // definition library
-  static Status ConvertLibFunction(const ExporterConfigs& configs,
+  static Status ConvertLibFunction(const GraphExportConfig& configs,
                                    const Dialect* tf_dialect,
                                    mlir::FuncOp function,
                                    FunctionDefLibrary* flib);
@@ -122,7 +122,7 @@ class Exporter {
   // Later on, this graph can be converted a function definition and added to
   // another graph.
   static StatusOr<std::unique_ptr<Graph>> Convert(
-      const ExporterConfigs& configs, const Dialect* tf_dialect,
+      const GraphExportConfig& configs, const Dialect* tf_dialect,
      mlir::FuncOp function, FunctionDefLibrary* flib);

  private:
@@ -377,7 +377,7 @@ Status Exporter::AddNextIterationNode(Operation* inst) {
 }

 StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
-    const ExporterConfigs& configs, const Dialect* tf_dialect,
+    const GraphExportConfig& configs, const Dialect* tf_dialect,
     mlir::FuncOp function, FunctionDefLibrary* flib) {
   if (function.getBlocks().size() != 1) {
     return errors::FailedPrecondition(
@@ -502,7 +502,7 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
   return graph;
 }

-Status Exporter::ConvertLibFunction(const ExporterConfigs& configs,
+Status Exporter::ConvertLibFunction(const GraphExportConfig& configs,
                                     const Dialect* tf_dialect,
                                     mlir::FuncOp function,
                                     FunctionDefLibrary* flib) {
@@ -560,7 +560,8 @@ Status Exporter::ConvertLibFunction(const ExporterConfigs& configs,
   return Status::OK();
 }

-Status Exporter::Convert(mlir::ModuleOp module, const ExporterConfigs& configs,
+Status Exporter::Convert(mlir::ModuleOp module,
+                         const GraphExportConfig& configs,
                          std::unique_ptr<Graph>* graph,
                          FunctionLibraryDefinition* flib_def) {
   mlir::Identifier entry_func_id =
@@ -596,7 +597,8 @@ Status Exporter::Convert(mlir::ModuleOp module, const ExporterConfigs& configs,
 }
 }  // namespace

-Status ConvertMlirToGraph(mlir::ModuleOp module, const ExporterConfigs& configs,
+Status ConvertMlirToGraph(mlir::ModuleOp module,
+                          const GraphExportConfig& configs,
                           std::unique_ptr<Graph>* graph,
                           FunctionLibraryDefinition* flib_def) {
   mlir::PassManager pass_manager(module.getContext());
@@ -609,7 +611,7 @@ Status ConvertMlirToGraph(mlir::ModuleOp module, const ExporterConfigs& configs,
 }

 StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
-    mlir::ModuleOp module, const ExporterConfigs& configs) {
+    mlir::ModuleOp module, const GraphExportConfig& configs) {
   FunctionLibraryDefinition flib_def(OpRegistry::Global(),
                                      FunctionDefLibrary());
   auto graph = absl::make_unique<Graph>(flib_def);
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h
index 93061a95239..6920768b73d 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h
+++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h
@@ -32,13 +32,13 @@ using stream_executor::port::StatusOr;

 // Given an MLIR module, returns a GraphDef.
 StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
-    mlir::ModuleOp module, const ExporterConfigs& configs);
+    mlir::ModuleOp module, const GraphExportConfig& configs);

 // Converts an MLIR module to TensorFlow graph and FunctionLibraryDefinition.
 // The "main" function of the module is stored in the graph and the rest of
 // functions are stored in the library.
 stream_executor::port::Status ConvertMlirToGraph(
-    mlir::ModuleOp module, const ExporterConfigs& confs,
+    mlir::ModuleOp module, const GraphExportConfig& confs,
     std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def);

 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index 4c03f2f6099..f6bf5b6a671 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -95,7 +95,7 @@ class ImporterBase {
  protected:
   explicit ImporterBase(
       const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
-      const NodeSpecs& specs, mlir::ModuleOp module,
+      const GraphImportConfig& specs, mlir::ModuleOp module,
       std::unordered_map<std::string, std::string>* tf_name_to_mlir_name)
       : builder_(module.getContext()),
         module_(module),
@@ -289,7 +289,7 @@ class ImporterBase {
   mlir::MLIRContext* context_;
   std::unordered_map<std::string, std::string>* tf_name_to_mlir_name_;
   const FunctionLibraryDefinition& graph_flib_;
-  const NodeSpecs& specs_;
+  const GraphImportConfig& specs_;
   const GraphDebugInfo& debug_info_;
   NodeValueMap node_values_;
   std::unique_ptr<ShapeRefiner> shape_refiner_;
@@ -315,7 +315,7 @@ bool HasNonPrimaryOutputInUse(const GraphDef& graph_def,
 // is in use and therefore can not be replaced by the Placeholder node that only
 // has a single output.
 Status UpdateLegacyFedInputNode(const GraphDef& graph_def,
-                                const NodeSpecs::InputArrays& inputs,
+                                const GraphImportConfig::InputArrays& inputs,
                                 NodeDef* node) {
   const std::string& node_name = node->name();
   auto it = inputs.find(node_name);
@@ -345,7 +345,7 @@ Status UpdateLegacyFedInputNode(const GraphDef& graph_def,
 //   the GraphDef.
 // - Replacing LegacyFedInput nodes with Placeholder nodes if
 //   convert_legacy_fed_inputs option is enabled.
-Status PreprocessGraphDef(const NodeSpecs* specs, GraphDef* graph_def) {
+Status PreprocessGraphDef(const GraphImportConfig* specs, GraphDef* graph_def) {
   const tensorflow::OpRegistrationData* op_reg_data;
   for (auto& node_def : *graph_def->mutable_node()) {
     // TODO(hinsu): Completely deprecate support for LegacyFedInput ops. One
@@ -880,7 +880,7 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
   // Converts the graph to a MLIR function and adds it to the module.
   // We populate the NodeSpec so that all the _Arg ops get their shape
   // added correctly.
-  NodeSpecs specs;
+  GraphImportConfig specs;
   for (const auto& name_and_value : func_def->attr()) {
     if (name_and_value.first == "_input_shapes") {
       auto& list = name_and_value.second.list();
@@ -1494,12 +1494,13 @@ class GraphDefImporter : public ImporterBase {
   static StatusOr<mlir::OwningModuleRef> Convert(
       mlir::MLIRContext* context, const Graph& graph,
       const GraphDebugInfo& debug_info,
-      const FunctionLibraryDefinition& flib_def, const NodeSpecs& specs);
+      const FunctionLibraryDefinition& flib_def,
+      const GraphImportConfig& specs);

  private:
   explicit GraphDefImporter(
       const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
-      const NodeSpecs& specs, mlir::ModuleOp module,
+      const GraphImportConfig& specs, mlir::ModuleOp module,
       std::unordered_map<std::string, std::string>* tf_name_to_mlir_name)
       : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name) {}

@@ -1509,7 +1510,7 @@ class GraphDefImporter : public ImporterBase {
   // information for the function returns are inferred by the shape refiner in
   // ImporterBase.
   StatusOr<mlir::FunctionType> InferMainFunctionType(
-      const NodeSpecs& specs, mlir::MLIRContext* context,
+      const GraphImportConfig& specs, mlir::MLIRContext* context,
       absl::InlinedVector<OutputTensor, 4>* arg_nodes,
       absl::InlinedVector<OutputTensor, 4>* ret_nodes);
 };
@@ -1517,7 +1518,7 @@ class GraphDefImporter : public ImporterBase {
 StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
     mlir::MLIRContext* context, const Graph& graph,
     const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
-    const NodeSpecs& specs) {
+    const GraphImportConfig& specs) {
   mlir::OwningModuleRef module =
       mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
   std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
@@ -1614,7 +1615,7 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
 }

 StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
-    const NodeSpecs& specs, mlir::MLIRContext* context,
+    const GraphImportConfig& specs, mlir::MLIRContext* context,
     absl::InlinedVector<OutputTensor, 4>* arg_nodes,
     absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
   // Finds out all the input nodes and output nodes.
@@ -1707,7 +1708,7 @@ class SavedModelImporter : public ImporterBase {
  private:
   explicit SavedModelImporter(
       const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
-      const NodeSpecs& specs, mlir::ModuleOp module,
+      const GraphImportConfig& specs, mlir::ModuleOp module,
       std::unordered_map<std::string, std::string>* tf_name_to_mlir_name)
       : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name) {}
 };
@@ -1715,7 +1716,7 @@ class SavedModelImporter : public ImporterBase {
 StatusOr<mlir::OwningModuleRef> SavedModelImporter::Convert(
     const MetaGraphDef& meta_graph, const GraphDebugInfo& debug_info,
     bool add_default_attributes, mlir::MLIRContext* context) {
-  NodeSpecs specs;
+  GraphImportConfig specs;
   mlir::OwningModuleRef module =
       mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
   std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
@@ -1750,7 +1751,7 @@ Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def) {

 StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(
     const GraphDef& graphdef, const GraphDebugInfo& debug_info,
-    const NodeSpecs& specs, mlir::MLIRContext* context,
+    const GraphImportConfig& specs, mlir::MLIRContext* context,
     bool add_default_attributes) {
   GraphConstructorOptions options;
   options.allow_internal_ops = true;
@@ -1774,7 +1775,7 @@ StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(

 StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
     const Graph& graph, const GraphDebugInfo& debug_info,
-    const FunctionLibraryDefinition& flib_def, const NodeSpecs& specs,
+    const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
     mlir::MLIRContext* context) {
   // TODO(jpienaar): Remove need to const_cast.
   if (specs.upgrade_legacy) {
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h
index 6ca4c0098d7..49155121d4e 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h
@@ -34,14 +34,14 @@ namespace tensorflow {
 // tf_executor dialect.
 stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(
     const GraphDef& graphdef, const GraphDebugInfo& debug_info,
-    const NodeSpecs& specs, mlir::MLIRContext* context,
+    const GraphImportConfig& specs, mlir::MLIRContext* context,
     bool add_default_attributes = true);

 // Given a Graph, returns a MLIR module containing the graph, expressed with
 // tf_executor dialect.
 stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
     const Graph& graph, const GraphDebugInfo& debug_info,
-    const FunctionLibraryDefinition& flib_def, const NodeSpecs& specs,
+    const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
     mlir::MLIRContext* context);

 // Given a SavedModel, returns a MLIR module containing the functions, expressed
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc
index 14b32224153..6e58caa358c 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc
@@ -82,7 +82,7 @@ Status ParseInputArrayInfo(absl::string_view array_names,
                            absl::string_view inference_type,
                            absl::string_view min_values,
                            absl::string_view max_values,
-                           NodeSpecs::InputArrays* inputs) {
+                           GraphImportConfig::InputArrays* inputs) {
   std::vector<string> node_names = absl::StrSplit(array_names, ',');
   std::vector<string> node_dtypes = absl::StrSplit(data_types, ',');

@@ -134,7 +134,7 @@ Status ParseInputArrayInfo(const std::vector<string>& node_names,
                            DataType inference_type,
                            const std::vector<double>& node_mins,
                            const std::vector<double>& node_maxs,
-                           NodeSpecs::InputArrays* inputs) {
+                           GraphImportConfig::InputArrays* inputs) {
   if (node_names.size() != node_dtypes.size() ||
       node_names.size() != node_shapes.size()) {
     return errors::FailedPrecondition(
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h
index 067591caf93..8a73ad9ef43 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h
+++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h
@@ -44,8 +44,7 @@ struct ArrayInfo {
   TensorShapeProto shape;
 };

-// TODO(jpienaar): Rename this the options in here are graph level too.
-struct NodeSpecs {
+struct GraphImportConfig {
   using InputArrays =
       llvm::MapVector<string, ArrayInfo, llvm::StringMap<unsigned>>;
   // Maps input node names to node data types and shapes.
@@ -69,7 +68,7 @@ struct NodeSpecs {
   bool upgrade_legacy = false;
 };

-struct ExporterConfigs {
+struct GraphExportConfig {
   // Whether to export shape attribute for the NodeDefs in the GraphDef.
   bool export_shapes = true;
   // Whether to export library field in the GraphDef.
@@ -103,7 +102,7 @@ Status ParseInputArrayInfo(absl::string_view array_names,
                            absl::string_view inference_type,
                            absl::string_view min_values,
                            absl::string_view max_values,
-                           NodeSpecs::InputArrays* inputs);
+                           GraphImportConfig::InputArrays* inputs);

 Status ParseInputArrayInfo(const std::vector<string>& node_names,
                            const std::vector<DataType>& node_dtypes,
@@ -111,7 +110,7 @@ Status ParseInputArrayInfo(const std::vector<string>& node_names,
                            DataType inference_type,
                            const std::vector<double>& node_mins,
                            const std::vector<double>& node_maxs,
-                           NodeSpecs::InputArrays* inputs);
+                           GraphImportConfig::InputArrays* inputs);

 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc
index f56c699d5ba..be5a95b57db 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc
@@ -35,7 +35,7 @@ static StatusOr<mlir::OwningModuleRef> Import(
     MLIRContext* context) {
   // TODO(fengliuai): get debug info at runtime.
   GraphDebugInfo debug_info;
-  NodeSpecs specs;
+  GraphImportConfig specs;
   TF_ASSIGN_OR_RETURN(
       auto module,
       ConvertGraphToMlir(graph, debug_info, *options.flib_def, specs, context));
@@ -51,7 +51,7 @@ static StatusOr<mlir::OwningModuleRef> Import(
 static Status Export(mlir::OwningModuleRef module,
                      const GraphOptimizationPassOptions& options,
                      std::unique_ptr<Graph>* graph) {
-  ExporterConfigs confs;
+  GraphExportConfig confs;
   return ConvertMlirToGraph(*module, confs, graph, options.flib_def);
 }

diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc
index 42b8e2afde3..16bc3bddcfa 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc
@@ -57,7 +57,7 @@ static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
     TF_RETURN_IF_ERROR(LoadProtoFromFile(debug_info_file, &debug_info));
   }

-  NodeSpecs specs;
+  GraphImportConfig specs;
   specs.prune_unused_nodes = prune_unused_nodes;
   specs.convert_legacy_fed_inputs = convert_legacy_fed_inputs;
   specs.graph_as_function = graph_as_function;
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc
index 165d7e2e562..1fdc5cdea9d 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc
@@ -69,7 +69,7 @@ static LogicalResult MlirToGraphdefTranslateFunction(
   if (!module) return failure();

   // TODO(fengliuai): Add exporter flags.
-  tensorflow::ExporterConfigs confs;
+  tensorflow::GraphExportConfig confs;
   StatusOr<std::unique_ptr<tensorflow::GraphDef>> graphdef_or(
       tensorflow::ConvertMlirToGraphdef(module, confs));
   if (!graphdef_or.status().ok()) {
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 3c9e92ceee2..9ba2918209a 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -586,6 +586,7 @@ cc_library(
     deps = [
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
+        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
         "//tensorflow/core:core_cpu_lib",
         "@llvm//:support",
     ],
diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
index 4349eef3fbc..373be531e60 100644
--- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
+++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
 #include "tensorflow/core/graph/graph_constructor.h"

 namespace tensorflow {
@@ -86,8 +87,8 @@ Status MlirBridgePass::Run(const GraphOptimizationPassOptions& options) {
   }
   GraphDebugInfo debug_info;
   mlir::MLIRContext context;
-  NodeSpecs specs;
-  ExporterConfigs confs;
+  GraphImportConfig specs;
+  GraphExportConfig confs;
   TF_ASSIGN_OR_RETURN(auto module,
                       ConvertGraphToMlir(**options.graph, debug_info,
                                          *options.flib_def, specs, &context));
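
For reference, a minimal sketch of how the renamed GraphImportConfig and GraphExportConfig are used together after this change, mirroring the call sites touched above in mlir.i, mlir_roundtrip_pass.cc, and mlir_bridge_pass.cc. The RoundTripGraphDef helper below is a hypothetical illustration, not code from the patch, and it assumes the GraphDef and GraphDebugInfo proto types are pulled in transitively by import_model.h, as at the existing call sites.

// Illustrative only: GraphDef -> MLIR -> GraphDef round trip using the renamed
// configs. The functions called are the ones whose signatures change above.
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"

namespace tensorflow {

StatusOr<std::unique_ptr<GraphDef>> RoundTripGraphDef(const GraphDef& graphdef) {
  GraphDebugInfo debug_info;        // empty debug info, as in ImportGraphDef in mlir.i
  GraphImportConfig import_config;  // formerly NodeSpecs
  mlir::MLIRContext context;
  auto module_or =
      ConvertGraphdefToMlir(graphdef, debug_info, import_config, &context);
  if (!module_or.ok()) return module_or.status();
  GraphExportConfig export_config;  // formerly ExporterConfigs
  return ConvertMlirToGraphdef(*module_or.ValueOrDie(), export_config);
}

}  // namespace tensorflow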