Rename import/export configs for Graph to MLIR consistently
NodeSpecs -> GraphImportConfig ExporterConfigs -> GraphExportConfig PiperOrigin-RevId: 271257644
This commit is contained in:
parent
a9ac8ebb64
commit
4a78601139
@ -90,7 +90,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
const GraphDef& input,
|
||||
string* result) {
|
||||
mlir::MLIRContext context;
|
||||
NodeSpecs specs;
|
||||
GraphImportConfig specs;
|
||||
|
||||
// Parse input arrays.
|
||||
std::vector<string> node_names;
|
||||
|
@ -40,7 +40,7 @@ string ImportGraphDef(const string &proto, const string &pass_pipeline, TF_Statu
|
||||
return "// error";
|
||||
}
|
||||
GraphDebugInfo debug_info;
|
||||
NodeSpecs specs;
|
||||
GraphImportConfig specs;
|
||||
mlir::MLIRContext context;
|
||||
auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
|
||||
if (!module.ok()) {
|
||||
|
@ -59,7 +59,7 @@ void GraphOptPass::runOnModule() {
|
||||
// Convert MLIR to Graph
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(),
|
||||
FunctionDefLibrary());
|
||||
ExporterConfigs confs;
|
||||
GraphExportConfig confs;
|
||||
auto graph = absl::make_unique<Graph>(flib_def);
|
||||
Status status = ConvertMlirToGraph(module_in, confs, &graph, &flib_def);
|
||||
if (!status.ok()) {
|
||||
@ -90,7 +90,7 @@ void GraphOptPass::runOnModule() {
|
||||
|
||||
// Convert Graph to MLIR
|
||||
GraphDebugInfo debug_info;
|
||||
NodeSpecs specs;
|
||||
GraphImportConfig specs;
|
||||
auto module_or_status =
|
||||
ConvertGraphToMlir(**options.graph, debug_info, flib_def, specs, &ctx);
|
||||
if (!module_or_status.ok()) {
|
||||
|
@ -107,13 +107,13 @@ class Exporter {
|
||||
// one entry function, which is identified by name "main". This entry function
|
||||
// is converted to the base of the graph graph. The rest of the functions are
|
||||
// converted to the library functions in that graph.
|
||||
static Status Convert(mlir::ModuleOp module, const ExporterConfigs& configs,
|
||||
static Status Convert(mlir::ModuleOp module, const GraphExportConfig& configs,
|
||||
std::unique_ptr<Graph>* graph,
|
||||
FunctionLibraryDefinition* flib_def);
|
||||
|
||||
// Converts a given FuncOp to a FunctionDef and adds it to the function
|
||||
// definition library
|
||||
static Status ConvertLibFunction(const ExporterConfigs& configs,
|
||||
static Status ConvertLibFunction(const GraphExportConfig& configs,
|
||||
const Dialect* tf_dialect,
|
||||
mlir::FuncOp function,
|
||||
FunctionDefLibrary* flib);
|
||||
@ -122,7 +122,7 @@ class Exporter {
|
||||
// Later on, this graph can be converted a function definition and added to
|
||||
// another graph.
|
||||
static StatusOr<std::unique_ptr<Graph>> Convert(
|
||||
const ExporterConfigs& configs, const Dialect* tf_dialect,
|
||||
const GraphExportConfig& configs, const Dialect* tf_dialect,
|
||||
mlir::FuncOp function, FunctionDefLibrary* flib);
|
||||
|
||||
private:
|
||||
@ -377,7 +377,7 @@ Status Exporter::AddNextIterationNode(Operation* inst) {
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
|
||||
const ExporterConfigs& configs, const Dialect* tf_dialect,
|
||||
const GraphExportConfig& configs, const Dialect* tf_dialect,
|
||||
mlir::FuncOp function, FunctionDefLibrary* flib) {
|
||||
if (function.getBlocks().size() != 1) {
|
||||
return errors::FailedPrecondition(
|
||||
@ -502,7 +502,7 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
|
||||
return graph;
|
||||
}
|
||||
|
||||
Status Exporter::ConvertLibFunction(const ExporterConfigs& configs,
|
||||
Status Exporter::ConvertLibFunction(const GraphExportConfig& configs,
|
||||
const Dialect* tf_dialect,
|
||||
mlir::FuncOp function,
|
||||
FunctionDefLibrary* flib) {
|
||||
@ -560,7 +560,8 @@ Status Exporter::ConvertLibFunction(const ExporterConfigs& configs,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Exporter::Convert(mlir::ModuleOp module, const ExporterConfigs& configs,
|
||||
Status Exporter::Convert(mlir::ModuleOp module,
|
||||
const GraphExportConfig& configs,
|
||||
std::unique_ptr<Graph>* graph,
|
||||
FunctionLibraryDefinition* flib_def) {
|
||||
mlir::Identifier entry_func_id =
|
||||
@ -596,7 +597,8 @@ Status Exporter::Convert(mlir::ModuleOp module, const ExporterConfigs& configs,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status ConvertMlirToGraph(mlir::ModuleOp module, const ExporterConfigs& configs,
|
||||
Status ConvertMlirToGraph(mlir::ModuleOp module,
|
||||
const GraphExportConfig& configs,
|
||||
std::unique_ptr<Graph>* graph,
|
||||
FunctionLibraryDefinition* flib_def) {
|
||||
mlir::PassManager pass_manager(module.getContext());
|
||||
@ -609,7 +611,7 @@ Status ConvertMlirToGraph(mlir::ModuleOp module, const ExporterConfigs& configs,
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
|
||||
mlir::ModuleOp module, const ExporterConfigs& configs) {
|
||||
mlir::ModuleOp module, const GraphExportConfig& configs) {
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(),
|
||||
FunctionDefLibrary());
|
||||
auto graph = absl::make_unique<Graph>(flib_def);
|
||||
|
@ -32,13 +32,13 @@ using stream_executor::port::StatusOr;
|
||||
|
||||
// Given an MLIR module, returns a GraphDef.
|
||||
StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
|
||||
mlir::ModuleOp module, const ExporterConfigs& configs);
|
||||
mlir::ModuleOp module, const GraphExportConfig& configs);
|
||||
|
||||
// Converts an MLIR module to TensorFlow graph and FunctionLibraryDefinition.
|
||||
// The "main" function of the module is stored in the graph and the rest of
|
||||
// functions are stored in the library.
|
||||
stream_executor::port::Status ConvertMlirToGraph(
|
||||
mlir::ModuleOp module, const ExporterConfigs& confs,
|
||||
mlir::ModuleOp module, const GraphExportConfig& confs,
|
||||
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def);
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -95,7 +95,7 @@ class ImporterBase {
|
||||
protected:
|
||||
explicit ImporterBase(
|
||||
const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
|
||||
const NodeSpecs& specs, mlir::ModuleOp module,
|
||||
const GraphImportConfig& specs, mlir::ModuleOp module,
|
||||
std::unordered_map<std::string, std::string>* tf_name_to_mlir_name)
|
||||
: builder_(module.getContext()),
|
||||
module_(module),
|
||||
@ -289,7 +289,7 @@ class ImporterBase {
|
||||
mlir::MLIRContext* context_;
|
||||
std::unordered_map<std::string, std::string>* tf_name_to_mlir_name_;
|
||||
const FunctionLibraryDefinition& graph_flib_;
|
||||
const NodeSpecs& specs_;
|
||||
const GraphImportConfig& specs_;
|
||||
const GraphDebugInfo& debug_info_;
|
||||
NodeValueMap node_values_;
|
||||
std::unique_ptr<ShapeRefiner> shape_refiner_;
|
||||
@ -315,7 +315,7 @@ bool HasNonPrimaryOutputInUse(const GraphDef& graph_def,
|
||||
// is in use and therefore can not be replaced by the Placeholder node that only
|
||||
// has a single output.
|
||||
Status UpdateLegacyFedInputNode(const GraphDef& graph_def,
|
||||
const NodeSpecs::InputArrays& inputs,
|
||||
const GraphImportConfig::InputArrays& inputs,
|
||||
NodeDef* node) {
|
||||
const std::string& node_name = node->name();
|
||||
auto it = inputs.find(node_name);
|
||||
@ -345,7 +345,7 @@ Status UpdateLegacyFedInputNode(const GraphDef& graph_def,
|
||||
// the GraphDef.
|
||||
// - Replacing LegacyFedInput nodes with Placeholder nodes if
|
||||
// convert_legacy_fed_inputs option is enabled.
|
||||
Status PreprocessGraphDef(const NodeSpecs* specs, GraphDef* graph_def) {
|
||||
Status PreprocessGraphDef(const GraphImportConfig* specs, GraphDef* graph_def) {
|
||||
const tensorflow::OpRegistrationData* op_reg_data;
|
||||
for (auto& node_def : *graph_def->mutable_node()) {
|
||||
// TODO(hinsu): Completely deprecate support for LegacyFedInput ops. One
|
||||
@ -880,7 +880,7 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
|
||||
// Converts the graph to a MLIR function and adds it to the module.
|
||||
// We populate the NodeSpec so that all the _Arg ops get their shape
|
||||
// added correctly.
|
||||
NodeSpecs specs;
|
||||
GraphImportConfig specs;
|
||||
for (const auto& name_and_value : func_def->attr()) {
|
||||
if (name_and_value.first == "_input_shapes") {
|
||||
auto& list = name_and_value.second.list();
|
||||
@ -1494,12 +1494,13 @@ class GraphDefImporter : public ImporterBase {
|
||||
static StatusOr<mlir::OwningModuleRef> Convert(
|
||||
mlir::MLIRContext* context, const Graph& graph,
|
||||
const GraphDebugInfo& debug_info,
|
||||
const FunctionLibraryDefinition& flib_def, const NodeSpecs& specs);
|
||||
const FunctionLibraryDefinition& flib_def,
|
||||
const GraphImportConfig& specs);
|
||||
|
||||
private:
|
||||
explicit GraphDefImporter(
|
||||
const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
|
||||
const NodeSpecs& specs, mlir::ModuleOp module,
|
||||
const GraphImportConfig& specs, mlir::ModuleOp module,
|
||||
std::unordered_map<std::string, std::string>* tf_name_to_mlir_name)
|
||||
: ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name) {}
|
||||
|
||||
@ -1509,7 +1510,7 @@ class GraphDefImporter : public ImporterBase {
|
||||
// information for the function returns are inferred by the shape refiner in
|
||||
// ImporterBase.
|
||||
StatusOr<mlir::FunctionType> InferMainFunctionType(
|
||||
const NodeSpecs& specs, mlir::MLIRContext* context,
|
||||
const GraphImportConfig& specs, mlir::MLIRContext* context,
|
||||
absl::InlinedVector<OutputTensor, 4>* arg_nodes,
|
||||
absl::InlinedVector<OutputTensor, 4>* ret_nodes);
|
||||
};
|
||||
@ -1517,7 +1518,7 @@ class GraphDefImporter : public ImporterBase {
|
||||
StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
||||
mlir::MLIRContext* context, const Graph& graph,
|
||||
const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
|
||||
const NodeSpecs& specs) {
|
||||
const GraphImportConfig& specs) {
|
||||
mlir::OwningModuleRef module =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
|
||||
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
|
||||
@ -1614,7 +1615,7 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
||||
}
|
||||
|
||||
StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
|
||||
const NodeSpecs& specs, mlir::MLIRContext* context,
|
||||
const GraphImportConfig& specs, mlir::MLIRContext* context,
|
||||
absl::InlinedVector<OutputTensor, 4>* arg_nodes,
|
||||
absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
|
||||
// Finds out all the input nodes and output nodes.
|
||||
@ -1707,7 +1708,7 @@ class SavedModelImporter : public ImporterBase {
|
||||
private:
|
||||
explicit SavedModelImporter(
|
||||
const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
|
||||
const NodeSpecs& specs, mlir::ModuleOp module,
|
||||
const GraphImportConfig& specs, mlir::ModuleOp module,
|
||||
std::unordered_map<std::string, std::string>* tf_name_to_mlir_name)
|
||||
: ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name) {}
|
||||
};
|
||||
@ -1715,7 +1716,7 @@ class SavedModelImporter : public ImporterBase {
|
||||
StatusOr<mlir::OwningModuleRef> SavedModelImporter::Convert(
|
||||
const MetaGraphDef& meta_graph, const GraphDebugInfo& debug_info,
|
||||
bool add_default_attributes, mlir::MLIRContext* context) {
|
||||
NodeSpecs specs;
|
||||
GraphImportConfig specs;
|
||||
mlir::OwningModuleRef module =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
|
||||
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
|
||||
@ -1750,7 +1751,7 @@ Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def) {
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(
|
||||
const GraphDef& graphdef, const GraphDebugInfo& debug_info,
|
||||
const NodeSpecs& specs, mlir::MLIRContext* context,
|
||||
const GraphImportConfig& specs, mlir::MLIRContext* context,
|
||||
bool add_default_attributes) {
|
||||
GraphConstructorOptions options;
|
||||
options.allow_internal_ops = true;
|
||||
@ -1774,7 +1775,7 @@ StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
|
||||
const Graph& graph, const GraphDebugInfo& debug_info,
|
||||
const FunctionLibraryDefinition& flib_def, const NodeSpecs& specs,
|
||||
const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
|
||||
mlir::MLIRContext* context) {
|
||||
// TODO(jpienaar): Remove need to const_cast.
|
||||
if (specs.upgrade_legacy) {
|
||||
|
@ -34,14 +34,14 @@ namespace tensorflow {
|
||||
// tf_executor dialect.
|
||||
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(
|
||||
const GraphDef& graphdef, const GraphDebugInfo& debug_info,
|
||||
const NodeSpecs& specs, mlir::MLIRContext* context,
|
||||
const GraphImportConfig& specs, mlir::MLIRContext* context,
|
||||
bool add_default_attributes = true);
|
||||
|
||||
// Given a Graph, returns a MLIR module containing the graph, expressed with
|
||||
// tf_executor dialect.
|
||||
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
|
||||
const Graph& graph, const GraphDebugInfo& debug_info,
|
||||
const FunctionLibraryDefinition& flib_def, const NodeSpecs& specs,
|
||||
const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
|
||||
mlir::MLIRContext* context);
|
||||
|
||||
// Given a SavedModel, returns a MLIR module containing the functions, expressed
|
||||
|
@ -82,7 +82,7 @@ Status ParseInputArrayInfo(absl::string_view array_names,
|
||||
absl::string_view inference_type,
|
||||
absl::string_view min_values,
|
||||
absl::string_view max_values,
|
||||
NodeSpecs::InputArrays* inputs) {
|
||||
GraphImportConfig::InputArrays* inputs) {
|
||||
std::vector<string> node_names = absl::StrSplit(array_names, ',');
|
||||
std::vector<string> node_dtypes = absl::StrSplit(data_types, ',');
|
||||
|
||||
@ -134,7 +134,7 @@ Status ParseInputArrayInfo(const std::vector<string>& node_names,
|
||||
DataType inference_type,
|
||||
const std::vector<float>& node_mins,
|
||||
const std::vector<float>& node_maxs,
|
||||
NodeSpecs::InputArrays* inputs) {
|
||||
GraphImportConfig::InputArrays* inputs) {
|
||||
if (node_names.size() != node_dtypes.size() ||
|
||||
node_names.size() != node_shapes.size()) {
|
||||
return errors::FailedPrecondition(
|
||||
|
@ -44,8 +44,7 @@ struct ArrayInfo {
|
||||
TensorShapeProto shape;
|
||||
};
|
||||
|
||||
// TODO(jpienaar): Rename this the options in here are graph level too.
|
||||
struct NodeSpecs {
|
||||
struct GraphImportConfig {
|
||||
using InputArrays =
|
||||
llvm::MapVector<string, ArrayInfo, llvm::StringMap<unsigned>>;
|
||||
// Maps input node names to node data types and shapes.
|
||||
@ -69,7 +68,7 @@ struct NodeSpecs {
|
||||
bool upgrade_legacy = false;
|
||||
};
|
||||
|
||||
struct ExporterConfigs {
|
||||
struct GraphExportConfig {
|
||||
// Whether to export shape attribute for the NodeDefs in the GraphDef.
|
||||
bool export_shapes = true;
|
||||
// Whether to export library field in the GraphDef.
|
||||
@ -103,7 +102,7 @@ Status ParseInputArrayInfo(absl::string_view array_names,
|
||||
absl::string_view inference_type,
|
||||
absl::string_view min_values,
|
||||
absl::string_view max_values,
|
||||
NodeSpecs::InputArrays* inputs);
|
||||
GraphImportConfig::InputArrays* inputs);
|
||||
|
||||
Status ParseInputArrayInfo(const std::vector<string>& node_names,
|
||||
const std::vector<string>& node_dtypes,
|
||||
@ -111,7 +110,7 @@ Status ParseInputArrayInfo(const std::vector<string>& node_names,
|
||||
DataType inference_type,
|
||||
const std::vector<float>& node_mins,
|
||||
const std::vector<float>& node_maxs,
|
||||
NodeSpecs::InputArrays* inputs);
|
||||
GraphImportConfig::InputArrays* inputs);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -35,7 +35,7 @@ static StatusOr<mlir::OwningModuleRef> Import(
|
||||
MLIRContext* context) {
|
||||
// TODO(fengliuai): get debug info at runtime.
|
||||
GraphDebugInfo debug_info;
|
||||
NodeSpecs specs;
|
||||
GraphImportConfig specs;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto module,
|
||||
ConvertGraphToMlir(graph, debug_info, *options.flib_def, specs, context));
|
||||
@ -51,7 +51,7 @@ static StatusOr<mlir::OwningModuleRef> Import(
|
||||
static Status Export(mlir::OwningModuleRef module,
|
||||
const GraphOptimizationPassOptions& options,
|
||||
std::unique_ptr<Graph>* graph) {
|
||||
ExporterConfigs confs;
|
||||
GraphExportConfig confs;
|
||||
return ConvertMlirToGraph(*module, confs, graph, options.flib_def);
|
||||
}
|
||||
|
||||
|
@ -57,7 +57,7 @@ static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
|
||||
TF_RETURN_IF_ERROR(LoadProtoFromFile(debug_info_file, &debug_info));
|
||||
}
|
||||
|
||||
NodeSpecs specs;
|
||||
GraphImportConfig specs;
|
||||
specs.prune_unused_nodes = prune_unused_nodes;
|
||||
specs.convert_legacy_fed_inputs = convert_legacy_fed_inputs;
|
||||
specs.graph_as_function = graph_as_function;
|
||||
|
@ -69,7 +69,7 @@ static LogicalResult MlirToGraphdefTranslateFunction(
|
||||
if (!module) return failure();
|
||||
|
||||
// TODO(fengliuai): Add exporter flags.
|
||||
tensorflow::ExporterConfigs confs;
|
||||
tensorflow::GraphExportConfig confs;
|
||||
StatusOr<std::unique_ptr<tensorflow::GraphDef>> graphdef_or(
|
||||
tensorflow::ConvertMlirToGraphdef(module, confs));
|
||||
if (!graphdef_or.status().ok()) {
|
||||
|
@ -586,6 +586,7 @@ cc_library(
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
"//tensorflow/core:core_cpu_lib",
|
||||
"@llvm//:support",
|
||||
],
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -86,8 +87,8 @@ Status MlirBridgePass::Run(const GraphOptimizationPassOptions& options) {
|
||||
}
|
||||
GraphDebugInfo debug_info;
|
||||
mlir::MLIRContext context;
|
||||
NodeSpecs specs;
|
||||
ExporterConfigs confs;
|
||||
GraphImportConfig specs;
|
||||
GraphExportConfig confs;
|
||||
TF_ASSIGN_OR_RETURN(auto module,
|
||||
ConvertGraphToMlir(**options.graph, debug_info,
|
||||
*options.flib_def, specs, &context));
|
||||
|
Loading…
Reference in New Issue
Block a user