Rename import/export configs for Graph to MLIR consistently

NodeSpecs -> GraphImportConfig
ExporterConfigs -> GraphExportConfig

PiperOrigin-RevId: 271257644
This commit is contained in:
Jacques Pienaar 2019-09-25 19:13:44 -07:00 committed by TensorFlower Gardener
parent a9ac8ebb64
commit 4a78601139
14 changed files with 47 additions and 43 deletions

View File

@ -90,7 +90,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
const GraphDef& input,
string* result) {
mlir::MLIRContext context;
NodeSpecs specs;
GraphImportConfig specs;
// Parse input arrays.
std::vector<string> node_names;

View File

@ -40,7 +40,7 @@ string ImportGraphDef(const string &proto, const string &pass_pipeline, TF_Statu
return "// error";
}
GraphDebugInfo debug_info;
NodeSpecs specs;
GraphImportConfig specs;
mlir::MLIRContext context;
auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
if (!module.ok()) {

View File

@ -59,7 +59,7 @@ void GraphOptPass::runOnModule() {
// Convert MLIR to Graph
FunctionLibraryDefinition flib_def(OpRegistry::Global(),
FunctionDefLibrary());
ExporterConfigs confs;
GraphExportConfig confs;
auto graph = absl::make_unique<Graph>(flib_def);
Status status = ConvertMlirToGraph(module_in, confs, &graph, &flib_def);
if (!status.ok()) {
@ -90,7 +90,7 @@ void GraphOptPass::runOnModule() {
// Convert Graph to MLIR
GraphDebugInfo debug_info;
NodeSpecs specs;
GraphImportConfig specs;
auto module_or_status =
ConvertGraphToMlir(**options.graph, debug_info, flib_def, specs, &ctx);
if (!module_or_status.ok()) {

View File

@ -107,13 +107,13 @@ class Exporter {
// one entry function, which is identified by name "main". This entry function
is converted to the base of the graph. The rest of the functions are
// converted to the library functions in that graph.
static Status Convert(mlir::ModuleOp module, const ExporterConfigs& configs,
static Status Convert(mlir::ModuleOp module, const GraphExportConfig& configs,
std::unique_ptr<Graph>* graph,
FunctionLibraryDefinition* flib_def);
// Converts a given FuncOp to a FunctionDef and adds it to the function
// definition library
static Status ConvertLibFunction(const ExporterConfigs& configs,
static Status ConvertLibFunction(const GraphExportConfig& configs,
const Dialect* tf_dialect,
mlir::FuncOp function,
FunctionDefLibrary* flib);
@ -122,7 +122,7 @@ class Exporter {
// Later on, this graph can be converted to a function definition and added to
// another graph.
static StatusOr<std::unique_ptr<Graph>> Convert(
const ExporterConfigs& configs, const Dialect* tf_dialect,
const GraphExportConfig& configs, const Dialect* tf_dialect,
mlir::FuncOp function, FunctionDefLibrary* flib);
private:
@ -377,7 +377,7 @@ Status Exporter::AddNextIterationNode(Operation* inst) {
}
StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
const ExporterConfigs& configs, const Dialect* tf_dialect,
const GraphExportConfig& configs, const Dialect* tf_dialect,
mlir::FuncOp function, FunctionDefLibrary* flib) {
if (function.getBlocks().size() != 1) {
return errors::FailedPrecondition(
@ -502,7 +502,7 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
return graph;
}
Status Exporter::ConvertLibFunction(const ExporterConfigs& configs,
Status Exporter::ConvertLibFunction(const GraphExportConfig& configs,
const Dialect* tf_dialect,
mlir::FuncOp function,
FunctionDefLibrary* flib) {
@ -560,7 +560,8 @@ Status Exporter::ConvertLibFunction(const ExporterConfigs& configs,
return Status::OK();
}
Status Exporter::Convert(mlir::ModuleOp module, const ExporterConfigs& configs,
Status Exporter::Convert(mlir::ModuleOp module,
const GraphExportConfig& configs,
std::unique_ptr<Graph>* graph,
FunctionLibraryDefinition* flib_def) {
mlir::Identifier entry_func_id =
@ -596,7 +597,8 @@ Status Exporter::Convert(mlir::ModuleOp module, const ExporterConfigs& configs,
}
} // namespace
Status ConvertMlirToGraph(mlir::ModuleOp module, const ExporterConfigs& configs,
Status ConvertMlirToGraph(mlir::ModuleOp module,
const GraphExportConfig& configs,
std::unique_ptr<Graph>* graph,
FunctionLibraryDefinition* flib_def) {
mlir::PassManager pass_manager(module.getContext());
@ -609,7 +611,7 @@ Status ConvertMlirToGraph(mlir::ModuleOp module, const ExporterConfigs& configs,
}
StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
mlir::ModuleOp module, const ExporterConfigs& configs) {
mlir::ModuleOp module, const GraphExportConfig& configs) {
FunctionLibraryDefinition flib_def(OpRegistry::Global(),
FunctionDefLibrary());
auto graph = absl::make_unique<Graph>(flib_def);

View File

@ -32,13 +32,13 @@ using stream_executor::port::StatusOr;
// Given an MLIR module, returns a GraphDef.
StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
mlir::ModuleOp module, const ExporterConfigs& configs);
mlir::ModuleOp module, const GraphExportConfig& configs);
// Converts an MLIR module to TensorFlow graph and FunctionLibraryDefinition.
// The "main" function of the module is stored in the graph and the rest of
// functions are stored in the library.
stream_executor::port::Status ConvertMlirToGraph(
mlir::ModuleOp module, const ExporterConfigs& confs,
mlir::ModuleOp module, const GraphExportConfig& confs,
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def);
} // namespace tensorflow

View File

@ -95,7 +95,7 @@ class ImporterBase {
protected:
explicit ImporterBase(
const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
const NodeSpecs& specs, mlir::ModuleOp module,
const GraphImportConfig& specs, mlir::ModuleOp module,
std::unordered_map<std::string, std::string>* tf_name_to_mlir_name)
: builder_(module.getContext()),
module_(module),
@ -289,7 +289,7 @@ class ImporterBase {
mlir::MLIRContext* context_;
std::unordered_map<std::string, std::string>* tf_name_to_mlir_name_;
const FunctionLibraryDefinition& graph_flib_;
const NodeSpecs& specs_;
const GraphImportConfig& specs_;
const GraphDebugInfo& debug_info_;
NodeValueMap node_values_;
std::unique_ptr<ShapeRefiner> shape_refiner_;
@ -315,7 +315,7 @@ bool HasNonPrimaryOutputInUse(const GraphDef& graph_def,
// is in use and therefore can not be replaced by the Placeholder node that only
// has a single output.
Status UpdateLegacyFedInputNode(const GraphDef& graph_def,
const NodeSpecs::InputArrays& inputs,
const GraphImportConfig::InputArrays& inputs,
NodeDef* node) {
const std::string& node_name = node->name();
auto it = inputs.find(node_name);
@ -345,7 +345,7 @@ Status UpdateLegacyFedInputNode(const GraphDef& graph_def,
// the GraphDef.
// - Replacing LegacyFedInput nodes with Placeholder nodes if
// convert_legacy_fed_inputs option is enabled.
Status PreprocessGraphDef(const NodeSpecs* specs, GraphDef* graph_def) {
Status PreprocessGraphDef(const GraphImportConfig* specs, GraphDef* graph_def) {
const tensorflow::OpRegistrationData* op_reg_data;
for (auto& node_def : *graph_def->mutable_node()) {
// TODO(hinsu): Completely deprecate support for LegacyFedInput ops. One
@ -880,7 +880,7 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
// Converts the graph to a MLIR function and adds it to the module.
// We populate the NodeSpec so that all the _Arg ops get their shape
// added correctly.
NodeSpecs specs;
GraphImportConfig specs;
for (const auto& name_and_value : func_def->attr()) {
if (name_and_value.first == "_input_shapes") {
auto& list = name_and_value.second.list();
@ -1494,12 +1494,13 @@ class GraphDefImporter : public ImporterBase {
static StatusOr<mlir::OwningModuleRef> Convert(
mlir::MLIRContext* context, const Graph& graph,
const GraphDebugInfo& debug_info,
const FunctionLibraryDefinition& flib_def, const NodeSpecs& specs);
const FunctionLibraryDefinition& flib_def,
const GraphImportConfig& specs);
private:
explicit GraphDefImporter(
const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
const NodeSpecs& specs, mlir::ModuleOp module,
const GraphImportConfig& specs, mlir::ModuleOp module,
std::unordered_map<std::string, std::string>* tf_name_to_mlir_name)
: ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name) {}
@ -1509,7 +1510,7 @@ class GraphDefImporter : public ImporterBase {
// information for the function returns are inferred by the shape refiner in
// ImporterBase.
StatusOr<mlir::FunctionType> InferMainFunctionType(
const NodeSpecs& specs, mlir::MLIRContext* context,
const GraphImportConfig& specs, mlir::MLIRContext* context,
absl::InlinedVector<OutputTensor, 4>* arg_nodes,
absl::InlinedVector<OutputTensor, 4>* ret_nodes);
};
@ -1517,7 +1518,7 @@ class GraphDefImporter : public ImporterBase {
StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
mlir::MLIRContext* context, const Graph& graph,
const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
const NodeSpecs& specs) {
const GraphImportConfig& specs) {
mlir::OwningModuleRef module =
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
@ -1614,7 +1615,7 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
}
StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
const NodeSpecs& specs, mlir::MLIRContext* context,
const GraphImportConfig& specs, mlir::MLIRContext* context,
absl::InlinedVector<OutputTensor, 4>* arg_nodes,
absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
// Finds out all the input nodes and output nodes.
@ -1707,7 +1708,7 @@ class SavedModelImporter : public ImporterBase {
private:
explicit SavedModelImporter(
const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
const NodeSpecs& specs, mlir::ModuleOp module,
const GraphImportConfig& specs, mlir::ModuleOp module,
std::unordered_map<std::string, std::string>* tf_name_to_mlir_name)
: ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name) {}
};
@ -1715,7 +1716,7 @@ class SavedModelImporter : public ImporterBase {
StatusOr<mlir::OwningModuleRef> SavedModelImporter::Convert(
const MetaGraphDef& meta_graph, const GraphDebugInfo& debug_info,
bool add_default_attributes, mlir::MLIRContext* context) {
NodeSpecs specs;
GraphImportConfig specs;
mlir::OwningModuleRef module =
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
@ -1750,7 +1751,7 @@ Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def) {
StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(
const GraphDef& graphdef, const GraphDebugInfo& debug_info,
const NodeSpecs& specs, mlir::MLIRContext* context,
const GraphImportConfig& specs, mlir::MLIRContext* context,
bool add_default_attributes) {
GraphConstructorOptions options;
options.allow_internal_ops = true;
@ -1774,7 +1775,7 @@ StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(
StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
const Graph& graph, const GraphDebugInfo& debug_info,
const FunctionLibraryDefinition& flib_def, const NodeSpecs& specs,
const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
mlir::MLIRContext* context) {
// TODO(jpienaar): Remove need to const_cast.
if (specs.upgrade_legacy) {

View File

@ -34,14 +34,14 @@ namespace tensorflow {
// tf_executor dialect.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(
const GraphDef& graphdef, const GraphDebugInfo& debug_info,
const NodeSpecs& specs, mlir::MLIRContext* context,
const GraphImportConfig& specs, mlir::MLIRContext* context,
bool add_default_attributes = true);
// Given a Graph, returns a MLIR module containing the graph, expressed with
// tf_executor dialect.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
const Graph& graph, const GraphDebugInfo& debug_info,
const FunctionLibraryDefinition& flib_def, const NodeSpecs& specs,
const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
mlir::MLIRContext* context);
// Given a SavedModel, returns a MLIR module containing the functions, expressed

View File

@ -82,7 +82,7 @@ Status ParseInputArrayInfo(absl::string_view array_names,
absl::string_view inference_type,
absl::string_view min_values,
absl::string_view max_values,
NodeSpecs::InputArrays* inputs) {
GraphImportConfig::InputArrays* inputs) {
std::vector<string> node_names = absl::StrSplit(array_names, ',');
std::vector<string> node_dtypes = absl::StrSplit(data_types, ',');
@ -134,7 +134,7 @@ Status ParseInputArrayInfo(const std::vector<string>& node_names,
DataType inference_type,
const std::vector<float>& node_mins,
const std::vector<float>& node_maxs,
NodeSpecs::InputArrays* inputs) {
GraphImportConfig::InputArrays* inputs) {
if (node_names.size() != node_dtypes.size() ||
node_names.size() != node_shapes.size()) {
return errors::FailedPrecondition(

View File

@ -44,8 +44,7 @@ struct ArrayInfo {
TensorShapeProto shape;
};
// TODO(jpienaar): Rename this the options in here are graph level too.
struct NodeSpecs {
struct GraphImportConfig {
using InputArrays =
llvm::MapVector<string, ArrayInfo, llvm::StringMap<unsigned>>;
// Maps input node names to node data types and shapes.
@ -69,7 +68,7 @@ struct NodeSpecs {
bool upgrade_legacy = false;
};
struct ExporterConfigs {
struct GraphExportConfig {
// Whether to export shape attribute for the NodeDefs in the GraphDef.
bool export_shapes = true;
// Whether to export library field in the GraphDef.
@ -103,7 +102,7 @@ Status ParseInputArrayInfo(absl::string_view array_names,
absl::string_view inference_type,
absl::string_view min_values,
absl::string_view max_values,
NodeSpecs::InputArrays* inputs);
GraphImportConfig::InputArrays* inputs);
Status ParseInputArrayInfo(const std::vector<string>& node_names,
const std::vector<string>& node_dtypes,
@ -111,7 +110,7 @@ Status ParseInputArrayInfo(const std::vector<string>& node_names,
DataType inference_type,
const std::vector<float>& node_mins,
const std::vector<float>& node_maxs,
NodeSpecs::InputArrays* inputs);
GraphImportConfig::InputArrays* inputs);
} // namespace tensorflow

View File

@ -35,7 +35,7 @@ static StatusOr<mlir::OwningModuleRef> Import(
MLIRContext* context) {
// TODO(fengliuai): get debug info at runtime.
GraphDebugInfo debug_info;
NodeSpecs specs;
GraphImportConfig specs;
TF_ASSIGN_OR_RETURN(
auto module,
ConvertGraphToMlir(graph, debug_info, *options.flib_def, specs, context));
@ -51,7 +51,7 @@ static StatusOr<mlir::OwningModuleRef> Import(
static Status Export(mlir::OwningModuleRef module,
const GraphOptimizationPassOptions& options,
std::unique_ptr<Graph>* graph) {
ExporterConfigs confs;
GraphExportConfig confs;
return ConvertMlirToGraph(*module, confs, graph, options.flib_def);
}

View File

@ -57,7 +57,7 @@ static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
TF_RETURN_IF_ERROR(LoadProtoFromFile(debug_info_file, &debug_info));
}
NodeSpecs specs;
GraphImportConfig specs;
specs.prune_unused_nodes = prune_unused_nodes;
specs.convert_legacy_fed_inputs = convert_legacy_fed_inputs;
specs.graph_as_function = graph_as_function;

View File

@ -69,7 +69,7 @@ static LogicalResult MlirToGraphdefTranslateFunction(
if (!module) return failure();
// TODO(fengliuai): Add exporter flags.
tensorflow::ExporterConfigs confs;
tensorflow::GraphExportConfig confs;
StatusOr<std::unique_ptr<tensorflow::GraphDef>> graphdef_or(
tensorflow::ConvertMlirToGraphdef(module, confs));
if (!graphdef_or.status().ok()) {

View File

@ -586,6 +586,7 @@ cc_library(
deps = [
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:core_cpu_lib",
"@llvm//:support",
],

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/graph/graph_constructor.h"
namespace tensorflow {
@ -86,8 +87,8 @@ Status MlirBridgePass::Run(const GraphOptimizationPassOptions& options) {
}
GraphDebugInfo debug_info;
mlir::MLIRContext context;
NodeSpecs specs;
ExporterConfigs confs;
GraphImportConfig specs;
GraphExportConfig confs;
TF_ASSIGN_OR_RETURN(auto module,
ConvertGraphToMlir(**options.graph, debug_info,
*options.flib_def, specs, &context));