Add python wrapper mlir.experimental.convert_function
for importing ConcreteFunctions into TF MLIR.
This takes a ConcreteFunction, collects a FunctionDef for the function and an associated FunctionDefLibrary, and imports the FunctionDef and FunctionDefLibrary via `ConvertFunctionToMlir`. Control rets/target nodes of the entry function are also now supported in `ConvertFunctionToMlir`. PiperOrigin-RevId: 331195841 Change-Id: Ib3a7264e90ca303ab7a850bf18c8d5e330063a4f
This commit is contained in:
parent
8a080f6493
commit
ba167d161e
tensorflow
compiler/mlir
python
tools/api/golden
@ -37,6 +37,8 @@ cc_library(
|
||||
"@llvm-project//mlir:Parser",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/InitAllPasses.h" // from @llvm-project
|
||||
#include "mlir/Parser.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
@ -28,9 +29,40 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
// Runs pass pipeline `pass_pipeline` on `module` if `pass_pipeline` is not
|
||||
// empty.
|
||||
std::string RunPassPipelineOnModule(mlir::ModuleOp module,
|
||||
const std::string &pass_pipeline,
|
||||
TF_Status *status) {
|
||||
if (!pass_pipeline.empty()) {
|
||||
mlir::PassManager pm(module.getContext());
|
||||
std::string error;
|
||||
llvm::raw_string_ostream error_stream(error);
|
||||
if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
("Invalid pass_pipeline: " + error_stream.str()).c_str());
|
||||
return "// error";
|
||||
}
|
||||
|
||||
mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext());
|
||||
if (failed(pm.run(module))) {
|
||||
Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus());
|
||||
return "// error";
|
||||
}
|
||||
}
|
||||
return MlirModuleToString(module);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
std::string ImportGraphDef(const std::string &proto,
|
||||
const std::string &pass_pipeline,
|
||||
TF_Status *status) {
|
||||
@ -49,24 +81,43 @@ std::string ImportGraphDef(const std::string &proto,
|
||||
return "// error";
|
||||
}
|
||||
|
||||
// Run the pass_pipeline on the module if not empty.
|
||||
if (!pass_pipeline.empty()) {
|
||||
mlir::PassManager pm(&context);
|
||||
std::string error;
|
||||
llvm::raw_string_ostream error_stream(error);
|
||||
if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
("Invalid pass_pipeline: " + error_stream.str()).c_str());
|
||||
return "// error";
|
||||
}
|
||||
return RunPassPipelineOnModule(module->get(), pass_pipeline, status);
|
||||
}
|
||||
|
||||
mlir::StatusScopedDiagnosticHandler statusHandler(&context);
|
||||
if (failed(pm.run(*module.ValueOrDie()))) {
|
||||
Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus());
|
||||
return "// error";
|
||||
}
|
||||
std::string ImportFunction(const std::string &functiondef_proto,
|
||||
const std::string &functiondef_library_proto,
|
||||
const std::string &pass_pipeline,
|
||||
TF_Status *status) {
|
||||
FunctionDef functiondef;
|
||||
auto s = tensorflow::LoadProtoFromBuffer(functiondef_proto, &functiondef);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
return "// error";
|
||||
}
|
||||
return MlirModuleToString(*module.ConsumeValueOrDie());
|
||||
|
||||
FunctionDefLibrary fdef_lib;
|
||||
s = tensorflow::LoadProtoFromBuffer(functiondef_library_proto, &fdef_lib);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
return "// error";
|
||||
}
|
||||
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
|
||||
s = flib_def.AddFunctionDef(functiondef);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
return "// error";
|
||||
}
|
||||
|
||||
const std::string &function_name = functiondef.signature().name();
|
||||
mlir::MLIRContext context;
|
||||
auto module = ConvertFunctionToMlir(function_name, flib_def, &context);
|
||||
if (!module.ok()) {
|
||||
Set_TF_Status_from_Status(status, module.status());
|
||||
return "// error";
|
||||
}
|
||||
|
||||
return RunPassPipelineOnModule(module->get(), pass_pipeline, status);
|
||||
}
|
||||
|
||||
std::string ExperimentalConvertSavedModelToMlir(
|
||||
|
@ -25,13 +25,23 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
// Simple wrapper to support tf.mlir.experimental.convert_graph_def.
|
||||
// Load a .pbptx, convert to MLIR, and (optionally) optimize the module before
|
||||
// returning it as a string.
|
||||
// Load a GraphDef (binary or textual proto format), convert to MLIR, and
|
||||
// (optionally) optimize the module before returning it as a string.
|
||||
// This is an early experimental API, ideally we should return a wrapper object
|
||||
// around a Python binding to the MLIR module.
|
||||
std::string ImportGraphDef(const std::string &proto,
|
||||
const std::string &pass_pipeline, TF_Status *status);
|
||||
|
||||
// Simple wrapper to support tf.mlir.experimental.convert_function.
|
||||
// Load FunctionDef and FunctionDefLibrary (binary or textual proto format),
|
||||
// convert to MLIR, and (optionally) optimize the module before returning it as
|
||||
// a string.
|
||||
// This is an early experimental API, ideally we should return a wrapper object
|
||||
// around a Python binding to the MLIR module.
|
||||
std::string ImportFunction(const std::string &functiondef_proto,
|
||||
const std::string &functiondef_library_proto,
|
||||
const std::string &pass_pipeline, TF_Status *status);
|
||||
|
||||
// Load a SavedModel and return a textual MLIR string corresponding to it.
|
||||
//
|
||||
// Args:
|
||||
|
@ -3654,6 +3654,8 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
|
||||
tensorflow::GraphDebugInfo dummy_debug_info;
|
||||
tensorflow::GraphImportConfig specs;
|
||||
specs.graph_as_function = true;
|
||||
for (const auto* control_ret_node : fbody->control_ret_nodes)
|
||||
specs.control_outputs.push_back(control_ret_node->name());
|
||||
return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info,
|
||||
flib_def, specs, name);
|
||||
}
|
||||
|
@ -11,7 +11,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:pywrap_mlir",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:tf_export",
|
||||
],
|
||||
)
|
||||
|
||||
@ -22,6 +22,10 @@ py_test(
|
||||
deps = [
|
||||
":mlir",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:logging_ops",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
],
|
||||
)
|
||||
|
@ -26,6 +26,9 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
def convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline'):
|
||||
"""Import a GraphDef and convert it to a textual MLIR module.
|
||||
|
||||
This API is only intended for inspecting the internals of TensorFlow and the
|
||||
string returned is at the moment intended for debugging purposes.
|
||||
|
||||
Args:
|
||||
graph_def: An object of type graph_pb2.GraphDef or a textual proto
|
||||
representation of a valid GraphDef.
|
||||
@ -35,7 +38,51 @@ def convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline'):
|
||||
|
||||
Returns:
|
||||
A textual representation of the MLIR module corresponding to the graphdef.
|
||||
Raises a RuntimeError on error.
|
||||
|
||||
Raises:
|
||||
InvalidArgumentError: if graph_def is invalid or cannot be converted to
|
||||
MLIR.
|
||||
|
||||
"""
|
||||
return pywrap_mlir.import_graphdef(graph_def, pass_pipeline)
|
||||
|
||||
|
||||
@tf_export('mlir.experimental.convert_function')
|
||||
def convert_function(concrete_function, pass_pipeline='tf-standard-pipeline'):
|
||||
"""Import a ConcreteFunction and convert it to a textual MLIR module.
|
||||
|
||||
This API is only intended for inspecting the internals of TensorFlow and the
|
||||
string returned is at the moment intended for debugging purposes.
|
||||
|
||||
A [tf.function](https://www.tensorflow.org/api_docs/python/tf/function) can be
|
||||
imported and converted from TensorFlow to TensorFlow MLIR with this API by
|
||||
extracting its ConcreteFunction (eagerly-executing wrapper around a
|
||||
[tf.Graph](https://www.tensorflow.org/api_docs/python/tf/Graph)).
|
||||
|
||||
For example:
|
||||
>>> @tf.function
|
||||
... def add(a, b):
|
||||
... return a + b
|
||||
|
||||
>>> concrete_function = add.get_concrete_function(
|
||||
... tf.TensorSpec(None, tf.dtypes.float32),
|
||||
... tf.TensorSpec(None, tf.dtypes.float32))
|
||||
>>> tf.mlir.experimental.convert_function(concrete_function)
|
||||
'...module attributes {...} {...}'
|
||||
|
||||
Args:
|
||||
concrete_function: An object of type ConcreteFunction.
|
||||
pass_pipeline: A textual description of an MLIR Pass Pipeline to run on the
|
||||
module, see MLIR documentation for the
|
||||
[textual pass pipeline syntax](https://mlir.llvm.org/docs/PassManagement/#textual-pass-pipeline-specification).
|
||||
|
||||
Returns:
|
||||
A textual representation of the MLIR module corresponding to the
|
||||
ConcreteFunction.
|
||||
|
||||
Raises:
|
||||
InvalidArgumentError: if concrete_function is invalid or cannot be converted
|
||||
to MLIR.
|
||||
|
||||
"""
|
||||
return pywrap_mlir.import_function(concrete_function, pass_pipeline)
|
||||
|
@ -19,23 +19,68 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compiler.mlir import mlir
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.ops import logging_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class MLIRImportTest(test.TestCase):
|
||||
class MLIRGraphDefImportTest(test.TestCase):
|
||||
|
||||
def test_import_graph_def(self):
|
||||
def testImport(self):
|
||||
"""Tests the basic flow of `tf.mlir.experimental.convert_graph_def`."""
|
||||
mlir_module = mlir.convert_graph_def('')
|
||||
# An empty graph should contain at least an empty main function.
|
||||
self.assertIn('func @main', mlir_module)
|
||||
|
||||
def test_invalid_pbtxt(self):
|
||||
def testInvalidPbtxt(self):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
'Could not parse input proto'):
|
||||
mlir.convert_graph_def('some invalid proto')
|
||||
|
||||
|
||||
class MLIRConcreteFunctionImportTest(test.TestCase):
|
||||
|
||||
def testImport(self):
|
||||
|
||||
@def_function.function
|
||||
def identity(i):
|
||||
return i
|
||||
|
||||
concrete_function = identity.get_concrete_function(
|
||||
tensor_spec.TensorSpec(None, dtypes.float32))
|
||||
mlir_module = mlir.convert_function(concrete_function)
|
||||
self.assertRegex(mlir_module, r'func @.*identity.*\(')
|
||||
|
||||
def testImportWithCall(self):
|
||||
|
||||
@def_function.function
|
||||
def callee(i):
|
||||
return i
|
||||
|
||||
@def_function.function
|
||||
def caller(i):
|
||||
return callee(i)
|
||||
|
||||
concrete_function = caller.get_concrete_function(
|
||||
tensor_spec.TensorSpec(None, dtypes.float32))
|
||||
mlir_module = mlir.convert_function(concrete_function)
|
||||
self.assertRegex(mlir_module, r'func @.*caller.*\(')
|
||||
self.assertRegex(mlir_module, r'func @.*callee.*\(')
|
||||
|
||||
def testImportWithControlRet(self):
|
||||
|
||||
@def_function.function
|
||||
def logging():
|
||||
logging_ops.print_v2('some message')
|
||||
|
||||
concrete_function = logging.get_concrete_function()
|
||||
mlir_module = mlir.convert_function(concrete_function, pass_pipeline='')
|
||||
self.assertRegex(mlir_module, r'tf\.PrintV2')
|
||||
self.assertRegex(mlir_module, r'tf_executor.fetch.*: !tf_executor.control')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -31,6 +31,17 @@ PYBIND11_MODULE(_pywrap_mlir, m) {
|
||||
return output;
|
||||
});
|
||||
|
||||
m.def("ImportFunction", [](const std::string &functiondef,
|
||||
const std::string &functiondef_library,
|
||||
const std::string &pass_pipeline) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
std::string output = tensorflow::ImportFunction(
|
||||
functiondef, functiondef_library, pass_pipeline, status.get());
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
return output;
|
||||
});
|
||||
|
||||
m.def("ExperimentalConvertSavedModelToMlir",
|
||||
[](const std::string &saved_model_path,
|
||||
const std::string &exported_names, bool show_debug_info) {
|
||||
|
@ -29,6 +29,13 @@ def import_graphdef(graphdef, pass_pipeline):
|
||||
pass_pipeline.encode('utf-8'))
|
||||
|
||||
|
||||
def import_function(concrete_function, pass_pipeline):
|
||||
return ImportFunction(
|
||||
str(concrete_function.function_def).encode('utf-8'),
|
||||
str(concrete_function.graph.as_graph_def().library).encode('utf-8'),
|
||||
pass_pipeline.encode('utf-8'))
|
||||
|
||||
|
||||
def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names,
|
||||
show_debug_info):
|
||||
return ExperimentalConvertSavedModelToMlir(
|
||||
|
@ -1,5 +1,9 @@
|
||||
path: "tensorflow.mlir.experimental"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "convert_function"
|
||||
argspec: "args=[\'concrete_function\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "convert_graph_def"
|
||||
argspec: "args=[\'graph_def\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
|
||||
|
@ -1,5 +1,9 @@
|
||||
path: "tensorflow.mlir.experimental"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "convert_function"
|
||||
argspec: "args=[\'concrete_function\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "convert_graph_def"
|
||||
argspec: "args=[\'graph_def\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user