Add python wrapper mlir.experimental.convert_function for importing ConcreteFunctions into TF MLIR.

This takes a ConcreteFunction, collects a FunctionDef for the function and an associated FunctionDefLibrary, and imports the FunctionDef and FunctionDefLibrary via `ConvertFunctionToMlir`.
Control rets/target nodes of the entry function are also now supported in `ConvertFunctionToMlir`.

PiperOrigin-RevId: 331195841
Change-Id: Ib3a7264e90ca303ab7a850bf18c8d5e330063a4f
This commit is contained in:
Andy Ly 2020-09-11 12:08:28 -07:00 committed by TensorFlower Gardener
parent 8a080f6493
commit ba167d161e
11 changed files with 211 additions and 24 deletions

View File

@ -37,6 +37,8 @@ cc_library(
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
],
alwayslink = 1,
)

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <string>
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/InitAllPasses.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
@ -28,9 +29,40 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
namespace {
// Runs pass pipeline `pass_pipeline` on `module` if `pass_pipeline` is not
// empty.
std::string RunPassPipelineOnModule(mlir::ModuleOp module,
const std::string &pass_pipeline,
TF_Status *status) {
if (!pass_pipeline.empty()) {
mlir::PassManager pm(module.getContext());
std::string error;
llvm::raw_string_ostream error_stream(error);
if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
("Invalid pass_pipeline: " + error_stream.str()).c_str());
return "// error";
}
mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext());
if (failed(pm.run(module))) {
Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus());
return "// error";
}
}
return MlirModuleToString(module);
}
} // anonymous namespace
std::string ImportGraphDef(const std::string &proto,
const std::string &pass_pipeline,
TF_Status *status) {
@ -49,24 +81,43 @@ std::string ImportGraphDef(const std::string &proto,
return "// error";
}
// Run the pass_pipeline on the module if not empty.
if (!pass_pipeline.empty()) {
mlir::PassManager pm(&context);
std::string error;
llvm::raw_string_ostream error_stream(error);
if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
("Invalid pass_pipeline: " + error_stream.str()).c_str());
return "// error";
}
return RunPassPipelineOnModule(module->get(), pass_pipeline, status);
}
mlir::StatusScopedDiagnosticHandler statusHandler(&context);
if (failed(pm.run(*module.ValueOrDie()))) {
Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus());
return "// error";
}
std::string ImportFunction(const std::string &functiondef_proto,
const std::string &functiondef_library_proto,
const std::string &pass_pipeline,
TF_Status *status) {
FunctionDef functiondef;
auto s = tensorflow::LoadProtoFromBuffer(functiondef_proto, &functiondef);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return "// error";
}
return MlirModuleToString(*module.ConsumeValueOrDie());
FunctionDefLibrary fdef_lib;
s = tensorflow::LoadProtoFromBuffer(functiondef_library_proto, &fdef_lib);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return "// error";
}
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
s = flib_def.AddFunctionDef(functiondef);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return "// error";
}
const std::string &function_name = functiondef.signature().name();
mlir::MLIRContext context;
auto module = ConvertFunctionToMlir(function_name, flib_def, &context);
if (!module.ok()) {
Set_TF_Status_from_Status(status, module.status());
return "// error";
}
return RunPassPipelineOnModule(module->get(), pass_pipeline, status);
}
std::string ExperimentalConvertSavedModelToMlir(

View File

@ -25,13 +25,23 @@ limitations under the License.
namespace tensorflow {
// Simple wrapper to support tf.mlir.experimental.convert_graph_def.
// Load a .pbptx, convert to MLIR, and (optionally) optimize the module before
// returning it as a string.
// Load a GraphDef (binary or textual proto format), convert to MLIR, and
// (optionally) optimize the module before returning it as a string.
// This is an early experimental API, ideally we should return a wrapper object
// around a Python binding to the MLIR module.
std::string ImportGraphDef(const std::string &proto,
const std::string &pass_pipeline, TF_Status *status);
// Simple wrapper to support tf.mlir.experimental.convert_function.
// Load FunctionDef and FunctionDefLibrary (binary or textual proto format),
// convert to MLIR, and (optionally) optimize the module before returning it as
// a string.
// This is an early experimental API, ideally we should return a wrapper object
// around a Python binding to the MLIR module.
std::string ImportFunction(const std::string &functiondef_proto,
const std::string &functiondef_library_proto,
const std::string &pass_pipeline, TF_Status *status);
// Load a SavedModel and return a textual MLIR string corresponding to it.
//
// Args:

View File

@ -3654,6 +3654,8 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
tensorflow::GraphDebugInfo dummy_debug_info;
tensorflow::GraphImportConfig specs;
specs.graph_as_function = true;
for (const auto* control_ret_node : fbody->control_ret_nodes)
specs.control_outputs.push_back(control_ret_node->name());
return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info,
flib_def, specs, name);
}

View File

@ -11,7 +11,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:pywrap_mlir",
"//tensorflow/python:util",
"//tensorflow/python:tf_export",
],
)
@ -22,6 +22,10 @@ py_test(
deps = [
":mlir",
"//tensorflow/python:client_testlib",
"//tensorflow/python:platform",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:logging_ops",
"//tensorflow/python:tensor_spec",
"//tensorflow/python/eager:def_function",
],
)

View File

@ -26,6 +26,9 @@ from tensorflow.python.util.tf_export import tf_export
def convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline'):
"""Import a GraphDef and convert it to a textual MLIR module.
This API is only intended for inspecting the internals of TensorFlow and the
string returned is at the moment intended for debugging purposes.
Args:
graph_def: An object of type graph_pb2.GraphDef or a textual proto
representation of a valid GraphDef.
@ -35,7 +38,51 @@ def convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline'):
Returns:
A textual representation of the MLIR module corresponding to the graphdef.
Raises a RuntimeError on error.
Raises:
InvalidArgumentError: if graph_def is invalid or cannot be converted to
MLIR.
"""
return pywrap_mlir.import_graphdef(graph_def, pass_pipeline)
@tf_export('mlir.experimental.convert_function')
def convert_function(concrete_function, pass_pipeline='tf-standard-pipeline'):
"""Import a ConcreteFunction and convert it to a textual MLIR module.
This API is only intended for inspecting the internals of TensorFlow and the
string returned is at the moment intended for debugging purposes.
A [tf.function](https://www.tensorflow.org/api_docs/python/tf/function) can be
imported and converted from TensorFlow to TensorFlow MLIR with this API by
extracting its ConcreteFunction (eagerly-executing wrapper around a
[tf.Graph](https://www.tensorflow.org/api_docs/python/tf/Graph)).
For example:
>>> @tf.function
... def add(a, b):
... return a + b
>>> concrete_function = add.get_concrete_function(
... tf.TensorSpec(None, tf.dtypes.float32),
... tf.TensorSpec(None, tf.dtypes.float32))
>>> tf.mlir.experimental.convert_function(concrete_function)
'...module attributes {...} {...}'
Args:
concrete_function: An object of type ConcreteFunction.
pass_pipeline: A textual description of an MLIR Pass Pipeline to run on the
module, see MLIR documentation for the
[textual pass pipeline syntax](https://mlir.llvm.org/docs/PassManagement/#textual-pass-pipeline-specification).
Returns:
A textual representation of the MLIR module corresponding to the
ConcreteFunction.
Raises:
InvalidArgumentError: if concrete_function is invalid or cannot be converted
to MLIR.
"""
return pywrap_mlir.import_function(concrete_function, pass_pipeline)

View File

@ -19,23 +19,68 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.compiler.mlir import mlir
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import logging_ops
from tensorflow.python.platform import test
class MLIRImportTest(test.TestCase):
class MLIRGraphDefImportTest(test.TestCase):
def test_import_graph_def(self):
def testImport(self):
"""Tests the basic flow of `tf.mlir.experimental.convert_graph_def`."""
mlir_module = mlir.convert_graph_def('')
# An empty graph should contain at least an empty main function.
self.assertIn('func @main', mlir_module)
def test_invalid_pbtxt(self):
def testInvalidPbtxt(self):
with self.assertRaisesRegex(errors.InvalidArgumentError,
'Could not parse input proto'):
mlir.convert_graph_def('some invalid proto')
class MLIRConcreteFunctionImportTest(test.TestCase):
def testImport(self):
@def_function.function
def identity(i):
return i
concrete_function = identity.get_concrete_function(
tensor_spec.TensorSpec(None, dtypes.float32))
mlir_module = mlir.convert_function(concrete_function)
self.assertRegex(mlir_module, r'func @.*identity.*\(')
def testImportWithCall(self):
@def_function.function
def callee(i):
return i
@def_function.function
def caller(i):
return callee(i)
concrete_function = caller.get_concrete_function(
tensor_spec.TensorSpec(None, dtypes.float32))
mlir_module = mlir.convert_function(concrete_function)
self.assertRegex(mlir_module, r'func @.*caller.*\(')
self.assertRegex(mlir_module, r'func @.*callee.*\(')
def testImportWithControlRet(self):
@def_function.function
def logging():
logging_ops.print_v2('some message')
concrete_function = logging.get_concrete_function()
mlir_module = mlir.convert_function(concrete_function, pass_pipeline='')
self.assertRegex(mlir_module, r'tf\.PrintV2')
self.assertRegex(mlir_module, r'tf_executor.fetch.*: !tf_executor.control')
if __name__ == '__main__':
test.main()

View File

@ -31,6 +31,17 @@ PYBIND11_MODULE(_pywrap_mlir, m) {
return output;
});
m.def("ImportFunction", [](const std::string &functiondef,
const std::string &functiondef_library,
const std::string &pass_pipeline) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
std::string output = tensorflow::ImportFunction(
functiondef, functiondef_library, pass_pipeline, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
return output;
});
m.def("ExperimentalConvertSavedModelToMlir",
[](const std::string &saved_model_path,
const std::string &exported_names, bool show_debug_info) {

View File

@ -29,6 +29,13 @@ def import_graphdef(graphdef, pass_pipeline):
pass_pipeline.encode('utf-8'))
def import_function(concrete_function, pass_pipeline):
return ImportFunction(
str(concrete_function.function_def).encode('utf-8'),
str(concrete_function.graph.as_graph_def().library).encode('utf-8'),
pass_pipeline.encode('utf-8'))
def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names,
show_debug_info):
return ExperimentalConvertSavedModelToMlir(

View File

@ -1,5 +1,9 @@
path: "tensorflow.mlir.experimental"
tf_module {
member_method {
name: "convert_function"
argspec: "args=[\'concrete_function\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
}
member_method {
name: "convert_graph_def"
argspec: "args=[\'graph_def\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "

View File

@ -1,5 +1,9 @@
path: "tensorflow.mlir.experimental"
tf_module {
member_method {
name: "convert_function"
argspec: "args=[\'concrete_function\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
}
member_method {
name: "convert_graph_def"
argspec: "args=[\'graph_def\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "