Add python wrapper mlir.experimental.convert_function
for importing ConcreteFunctions into TF MLIR.
This takes a ConcreteFunction, collects a FunctionDef for the function and an associated FunctionDefLibrary, and imports the FunctionDef and FunctionDefLibrary via `ConvertFunctionToMlir`. Control rets/target nodes of the entry function are also now supported in `ConvertFunctionToMlir`. PiperOrigin-RevId: 331195841 Change-Id: Ib3a7264e90ca303ab7a850bf18c8d5e330063a4f
This commit is contained in:
parent
8a080f6493
commit
ba167d161e
@ -37,6 +37,8 @@ cc_library(
|
|||||||
"@llvm-project//mlir:Parser",
|
"@llvm-project//mlir:Parser",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
#include "mlir/IR/Module.h" // from @llvm-project
|
||||||
#include "mlir/InitAllPasses.h" // from @llvm-project
|
#include "mlir/InitAllPasses.h" // from @llvm-project
|
||||||
#include "mlir/Parser.h" // from @llvm-project
|
#include "mlir/Parser.h" // from @llvm-project
|
||||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||||
@ -28,9 +29,40 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
|
||||||
|
#include "tensorflow/core/framework/function.h"
|
||||||
|
#include "tensorflow/core/framework/function.pb.h"
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Runs pass pipeline `pass_pipeline` on `module` if `pass_pipeline` is not
|
||||||
|
// empty.
|
||||||
|
std::string RunPassPipelineOnModule(mlir::ModuleOp module,
|
||||||
|
const std::string &pass_pipeline,
|
||||||
|
TF_Status *status) {
|
||||||
|
if (!pass_pipeline.empty()) {
|
||||||
|
mlir::PassManager pm(module.getContext());
|
||||||
|
std::string error;
|
||||||
|
llvm::raw_string_ostream error_stream(error);
|
||||||
|
if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
|
||||||
|
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||||
|
("Invalid pass_pipeline: " + error_stream.str()).c_str());
|
||||||
|
return "// error";
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext());
|
||||||
|
if (failed(pm.run(module))) {
|
||||||
|
Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus());
|
||||||
|
return "// error";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return MlirModuleToString(module);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
std::string ImportGraphDef(const std::string &proto,
|
std::string ImportGraphDef(const std::string &proto,
|
||||||
const std::string &pass_pipeline,
|
const std::string &pass_pipeline,
|
||||||
TF_Status *status) {
|
TF_Status *status) {
|
||||||
@ -49,24 +81,43 @@ std::string ImportGraphDef(const std::string &proto,
|
|||||||
return "// error";
|
return "// error";
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run the pass_pipeline on the module if not empty.
|
return RunPassPipelineOnModule(module->get(), pass_pipeline, status);
|
||||||
if (!pass_pipeline.empty()) {
|
}
|
||||||
mlir::PassManager pm(&context);
|
|
||||||
std::string error;
|
std::string ImportFunction(const std::string &functiondef_proto,
|
||||||
llvm::raw_string_ostream error_stream(error);
|
const std::string &functiondef_library_proto,
|
||||||
if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
|
const std::string &pass_pipeline,
|
||||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
TF_Status *status) {
|
||||||
("Invalid pass_pipeline: " + error_stream.str()).c_str());
|
FunctionDef functiondef;
|
||||||
|
auto s = tensorflow::LoadProtoFromBuffer(functiondef_proto, &functiondef);
|
||||||
|
if (!s.ok()) {
|
||||||
|
Set_TF_Status_from_Status(status, s);
|
||||||
return "// error";
|
return "// error";
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::StatusScopedDiagnosticHandler statusHandler(&context);
|
FunctionDefLibrary fdef_lib;
|
||||||
if (failed(pm.run(*module.ValueOrDie()))) {
|
s = tensorflow::LoadProtoFromBuffer(functiondef_library_proto, &fdef_lib);
|
||||||
Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus());
|
if (!s.ok()) {
|
||||||
|
Set_TF_Status_from_Status(status, s);
|
||||||
return "// error";
|
return "// error";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
|
||||||
|
s = flib_def.AddFunctionDef(functiondef);
|
||||||
|
if (!s.ok()) {
|
||||||
|
Set_TF_Status_from_Status(status, s);
|
||||||
|
return "// error";
|
||||||
}
|
}
|
||||||
return MlirModuleToString(*module.ConsumeValueOrDie());
|
|
||||||
|
const std::string &function_name = functiondef.signature().name();
|
||||||
|
mlir::MLIRContext context;
|
||||||
|
auto module = ConvertFunctionToMlir(function_name, flib_def, &context);
|
||||||
|
if (!module.ok()) {
|
||||||
|
Set_TF_Status_from_Status(status, module.status());
|
||||||
|
return "// error";
|
||||||
|
}
|
||||||
|
|
||||||
|
return RunPassPipelineOnModule(module->get(), pass_pipeline, status);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string ExperimentalConvertSavedModelToMlir(
|
std::string ExperimentalConvertSavedModelToMlir(
|
||||||
|
@ -25,13 +25,23 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
// Simple wrapper to support tf.mlir.experimental.convert_graph_def.
|
// Simple wrapper to support tf.mlir.experimental.convert_graph_def.
|
||||||
// Load a .pbptx, convert to MLIR, and (optionally) optimize the module before
|
// Load a GraphDef (binary or textual proto format), convert to MLIR, and
|
||||||
// returning it as a string.
|
// (optionally) optimize the module before returning it as a string.
|
||||||
// This is an early experimental API, ideally we should return a wrapper object
|
// This is an early experimental API, ideally we should return a wrapper object
|
||||||
// around a Python binding to the MLIR module.
|
// around a Python binding to the MLIR module.
|
||||||
std::string ImportGraphDef(const std::string &proto,
|
std::string ImportGraphDef(const std::string &proto,
|
||||||
const std::string &pass_pipeline, TF_Status *status);
|
const std::string &pass_pipeline, TF_Status *status);
|
||||||
|
|
||||||
|
// Simple wrapper to support tf.mlir.experimental.convert_function.
|
||||||
|
// Load FunctionDef and FunctionDefLibrary (binary or textual proto format),
|
||||||
|
// convert to MLIR, and (optionally) optimize the module before returning it as
|
||||||
|
// a string.
|
||||||
|
// This is an early experimental API, ideally we should return a wrapper object
|
||||||
|
// around a Python binding to the MLIR module.
|
||||||
|
std::string ImportFunction(const std::string &functiondef_proto,
|
||||||
|
const std::string &functiondef_library_proto,
|
||||||
|
const std::string &pass_pipeline, TF_Status *status);
|
||||||
|
|
||||||
// Load a SavedModel and return a textual MLIR string corresponding to it.
|
// Load a SavedModel and return a textual MLIR string corresponding to it.
|
||||||
//
|
//
|
||||||
// Args:
|
// Args:
|
||||||
|
@ -3654,6 +3654,8 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
|
|||||||
tensorflow::GraphDebugInfo dummy_debug_info;
|
tensorflow::GraphDebugInfo dummy_debug_info;
|
||||||
tensorflow::GraphImportConfig specs;
|
tensorflow::GraphImportConfig specs;
|
||||||
specs.graph_as_function = true;
|
specs.graph_as_function = true;
|
||||||
|
for (const auto* control_ret_node : fbody->control_ret_nodes)
|
||||||
|
specs.control_outputs.push_back(control_ret_node->name());
|
||||||
return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info,
|
return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info,
|
||||||
flib_def, specs, name);
|
flib_def, specs, name);
|
||||||
}
|
}
|
||||||
|
@ -11,7 +11,7 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:pywrap_mlir",
|
"//tensorflow/python:pywrap_mlir",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:tf_export",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -22,6 +22,10 @@ py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":mlir",
|
":mlir",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:dtypes",
|
||||||
|
"//tensorflow/python:errors",
|
||||||
|
"//tensorflow/python:logging_ops",
|
||||||
|
"//tensorflow/python:tensor_spec",
|
||||||
|
"//tensorflow/python/eager:def_function",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -26,6 +26,9 @@ from tensorflow.python.util.tf_export import tf_export
|
|||||||
def convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline'):
|
def convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline'):
|
||||||
"""Import a GraphDef and convert it to a textual MLIR module.
|
"""Import a GraphDef and convert it to a textual MLIR module.
|
||||||
|
|
||||||
|
This API is only intended for inspecting the internals of TensorFlow and the
|
||||||
|
string returned is at the moment intended for debugging purposes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_def: An object of type graph_pb2.GraphDef or a textual proto
|
graph_def: An object of type graph_pb2.GraphDef or a textual proto
|
||||||
representation of a valid GraphDef.
|
representation of a valid GraphDef.
|
||||||
@ -35,7 +38,51 @@ def convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline'):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A textual representation of the MLIR module corresponding to the graphdef.
|
A textual representation of the MLIR module corresponding to the graphdef.
|
||||||
Raises a RuntimeError on error.
|
|
||||||
|
Raises:
|
||||||
|
InvalidArgumentError: if graph_def is invalid or cannot be converted to
|
||||||
|
MLIR.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return pywrap_mlir.import_graphdef(graph_def, pass_pipeline)
|
return pywrap_mlir.import_graphdef(graph_def, pass_pipeline)
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export('mlir.experimental.convert_function')
|
||||||
|
def convert_function(concrete_function, pass_pipeline='tf-standard-pipeline'):
|
||||||
|
"""Import a ConcreteFunction and convert it to a textual MLIR module.
|
||||||
|
|
||||||
|
This API is only intended for inspecting the internals of TensorFlow and the
|
||||||
|
string returned is at the moment intended for debugging purposes.
|
||||||
|
|
||||||
|
A [tf.function](https://www.tensorflow.org/api_docs/python/tf/function) can be
|
||||||
|
imported and converted from TensorFlow to TensorFlow MLIR with this API by
|
||||||
|
extracting its ConcreteFunction (eagerly-executing wrapper around a
|
||||||
|
[tf.Graph](https://www.tensorflow.org/api_docs/python/tf/Graph)).
|
||||||
|
|
||||||
|
For example:
|
||||||
|
>>> @tf.function
|
||||||
|
... def add(a, b):
|
||||||
|
... return a + b
|
||||||
|
|
||||||
|
>>> concrete_function = add.get_concrete_function(
|
||||||
|
... tf.TensorSpec(None, tf.dtypes.float32),
|
||||||
|
... tf.TensorSpec(None, tf.dtypes.float32))
|
||||||
|
>>> tf.mlir.experimental.convert_function(concrete_function)
|
||||||
|
'...module attributes {...} {...}'
|
||||||
|
|
||||||
|
Args:
|
||||||
|
concrete_function: An object of type ConcreteFunction.
|
||||||
|
pass_pipeline: A textual description of an MLIR Pass Pipeline to run on the
|
||||||
|
module, see MLIR documentation for the
|
||||||
|
[textual pass pipeline syntax](https://mlir.llvm.org/docs/PassManagement/#textual-pass-pipeline-specification).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A textual representation of the MLIR module corresponding to the
|
||||||
|
ConcreteFunction.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidArgumentError: if concrete_function is invalid or cannot be converted
|
||||||
|
to MLIR.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return pywrap_mlir.import_function(concrete_function, pass_pipeline)
|
||||||
|
@ -19,23 +19,68 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.compiler.mlir import mlir
|
from tensorflow.python.compiler.mlir import mlir
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
|
from tensorflow.python.framework import tensor_spec
|
||||||
|
from tensorflow.python.ops import logging_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class MLIRImportTest(test.TestCase):
|
class MLIRGraphDefImportTest(test.TestCase):
|
||||||
|
|
||||||
def test_import_graph_def(self):
|
def testImport(self):
|
||||||
"""Tests the basic flow of `tf.mlir.experimental.convert_graph_def`."""
|
"""Tests the basic flow of `tf.mlir.experimental.convert_graph_def`."""
|
||||||
mlir_module = mlir.convert_graph_def('')
|
mlir_module = mlir.convert_graph_def('')
|
||||||
# An empty graph should contain at least an empty main function.
|
# An empty graph should contain at least an empty main function.
|
||||||
self.assertIn('func @main', mlir_module)
|
self.assertIn('func @main', mlir_module)
|
||||||
|
|
||||||
def test_invalid_pbtxt(self):
|
def testInvalidPbtxt(self):
|
||||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||||
'Could not parse input proto'):
|
'Could not parse input proto'):
|
||||||
mlir.convert_graph_def('some invalid proto')
|
mlir.convert_graph_def('some invalid proto')
|
||||||
|
|
||||||
|
|
||||||
|
class MLIRConcreteFunctionImportTest(test.TestCase):
|
||||||
|
|
||||||
|
def testImport(self):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def identity(i):
|
||||||
|
return i
|
||||||
|
|
||||||
|
concrete_function = identity.get_concrete_function(
|
||||||
|
tensor_spec.TensorSpec(None, dtypes.float32))
|
||||||
|
mlir_module = mlir.convert_function(concrete_function)
|
||||||
|
self.assertRegex(mlir_module, r'func @.*identity.*\(')
|
||||||
|
|
||||||
|
def testImportWithCall(self):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def callee(i):
|
||||||
|
return i
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def caller(i):
|
||||||
|
return callee(i)
|
||||||
|
|
||||||
|
concrete_function = caller.get_concrete_function(
|
||||||
|
tensor_spec.TensorSpec(None, dtypes.float32))
|
||||||
|
mlir_module = mlir.convert_function(concrete_function)
|
||||||
|
self.assertRegex(mlir_module, r'func @.*caller.*\(')
|
||||||
|
self.assertRegex(mlir_module, r'func @.*callee.*\(')
|
||||||
|
|
||||||
|
def testImportWithControlRet(self):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def logging():
|
||||||
|
logging_ops.print_v2('some message')
|
||||||
|
|
||||||
|
concrete_function = logging.get_concrete_function()
|
||||||
|
mlir_module = mlir.convert_function(concrete_function, pass_pipeline='')
|
||||||
|
self.assertRegex(mlir_module, r'tf\.PrintV2')
|
||||||
|
self.assertRegex(mlir_module, r'tf_executor.fetch.*: !tf_executor.control')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -31,6 +31,17 @@ PYBIND11_MODULE(_pywrap_mlir, m) {
|
|||||||
return output;
|
return output;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
m.def("ImportFunction", [](const std::string &functiondef,
|
||||||
|
const std::string &functiondef_library,
|
||||||
|
const std::string &pass_pipeline) {
|
||||||
|
tensorflow::Safe_TF_StatusPtr status =
|
||||||
|
tensorflow::make_safe(TF_NewStatus());
|
||||||
|
std::string output = tensorflow::ImportFunction(
|
||||||
|
functiondef, functiondef_library, pass_pipeline, status.get());
|
||||||
|
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||||
|
return output;
|
||||||
|
});
|
||||||
|
|
||||||
m.def("ExperimentalConvertSavedModelToMlir",
|
m.def("ExperimentalConvertSavedModelToMlir",
|
||||||
[](const std::string &saved_model_path,
|
[](const std::string &saved_model_path,
|
||||||
const std::string &exported_names, bool show_debug_info) {
|
const std::string &exported_names, bool show_debug_info) {
|
||||||
|
@ -29,6 +29,13 @@ def import_graphdef(graphdef, pass_pipeline):
|
|||||||
pass_pipeline.encode('utf-8'))
|
pass_pipeline.encode('utf-8'))
|
||||||
|
|
||||||
|
|
||||||
|
def import_function(concrete_function, pass_pipeline):
|
||||||
|
return ImportFunction(
|
||||||
|
str(concrete_function.function_def).encode('utf-8'),
|
||||||
|
str(concrete_function.graph.as_graph_def().library).encode('utf-8'),
|
||||||
|
pass_pipeline.encode('utf-8'))
|
||||||
|
|
||||||
|
|
||||||
def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names,
|
def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names,
|
||||||
show_debug_info):
|
show_debug_info):
|
||||||
return ExperimentalConvertSavedModelToMlir(
|
return ExperimentalConvertSavedModelToMlir(
|
||||||
|
@ -1,5 +1,9 @@
|
|||||||
path: "tensorflow.mlir.experimental"
|
path: "tensorflow.mlir.experimental"
|
||||||
tf_module {
|
tf_module {
|
||||||
|
member_method {
|
||||||
|
name: "convert_function"
|
||||||
|
argspec: "args=[\'concrete_function\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "convert_graph_def"
|
name: "convert_graph_def"
|
||||||
argspec: "args=[\'graph_def\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
|
argspec: "args=[\'graph_def\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
|
||||||
|
@ -1,5 +1,9 @@
|
|||||||
path: "tensorflow.mlir.experimental"
|
path: "tensorflow.mlir.experimental"
|
||||||
tf_module {
|
tf_module {
|
||||||
|
member_method {
|
||||||
|
name: "convert_function"
|
||||||
|
argspec: "args=[\'concrete_function\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "convert_graph_def"
|
name: "convert_graph_def"
|
||||||
argspec: "args=[\'graph_def\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
|
argspec: "args=[\'graph_def\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
|
||||||
|
Loading…
x
Reference in New Issue
Block a user