From ba167d161e800af71f6d4d274394dd044937991f Mon Sep 17 00:00:00 2001 From: Andy Ly <lyandy@google.com> Date: Fri, 11 Sep 2020 12:08:28 -0700 Subject: [PATCH] Add python wrapper `mlir.experimental.convert_function` for importing ConcreteFunctions into TF MLIR. This takes a ConcreteFunction, collects a FunctionDef for the function and an associated FunctionDefLibrary, and imports the FunctionDef and FunctionDefLibrary via `ConvertFunctionToMlir`. Control rets/target nodes of the entry function are also now supported in `ConvertFunctionToMlir`. PiperOrigin-RevId: 331195841 Change-Id: Ib3a7264e90ca303ab7a850bf18c8d5e330063a4f --- tensorflow/compiler/mlir/python/BUILD | 2 + tensorflow/compiler/mlir/python/mlir.cc | 83 +++++++++++++++---- tensorflow/compiler/mlir/python/mlir.h | 14 +++- .../mlir/tensorflow/translate/import_model.cc | 2 + tensorflow/python/compiler/mlir/BUILD | 8 +- tensorflow/python/compiler/mlir/mlir.py | 49 ++++++++++- tensorflow/python/compiler/mlir/mlir_test.py | 51 +++++++++++- tensorflow/python/mlir_wrapper.cc | 11 +++ tensorflow/python/pywrap_mlir.py | 7 ++ .../v1/tensorflow.mlir.experimental.pbtxt | 4 + .../v2/tensorflow.mlir.experimental.pbtxt | 4 + 11 files changed, 211 insertions(+), 24 deletions(-) diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD index 6a47be332d0..66283bded71 100644 --- a/tensorflow/compiler/mlir/python/BUILD +++ b/tensorflow/compiler/mlir/python/BUILD @@ -37,6 +37,8 @@ cc_library( "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc index 8bec288cda5..066726593a7 100644 --- a/tensorflow/compiler/mlir/python/mlir.cc +++ b/tensorflow/compiler/mlir/python/mlir.cc @@ -16,6 +16,7 @@ limitations under the License. #include <string> #include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Module.h" // from @llvm-project #include "mlir/InitAllPasses.h" // from @llvm-project #include "mlir/Parser.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project @@ -28,9 +29,40 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/op.h" namespace tensorflow { +namespace { + +// Runs pass pipeline `pass_pipeline` on `module` if `pass_pipeline` is not +// empty. +std::string RunPassPipelineOnModule(mlir::ModuleOp module, + const std::string &pass_pipeline, + TF_Status *status) { + if (!pass_pipeline.empty()) { + mlir::PassManager pm(module.getContext()); + std::string error; + llvm::raw_string_ostream error_stream(error); + if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + ("Invalid pass_pipeline: " + error_stream.str()).c_str()); + return "// error"; + } + + mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext()); + if (failed(pm.run(module))) { + Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus()); + return "// error"; + } + } + return MlirModuleToString(module); +} + +} // anonymous namespace + std::string ImportGraphDef(const std::string &proto, const std::string &pass_pipeline, TF_Status *status) { @@ -49,24 +81,43 @@ std::string ImportGraphDef(const std::string &proto, return "// error"; } - // Run the pass_pipeline on the module if not empty. - if (!pass_pipeline.empty()) { - mlir::PassManager pm(&context); - std::string error; - llvm::raw_string_ostream error_stream(error); - if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) { - TF_SetStatus(status, TF_INVALID_ARGUMENT, - ("Invalid pass_pipeline: " + error_stream.str()).c_str()); - return "// error"; - } + return RunPassPipelineOnModule(module->get(), pass_pipeline, status); +} - mlir::StatusScopedDiagnosticHandler statusHandler(&context); - if (failed(pm.run(*module.ValueOrDie()))) { - Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus()); - return "// error"; - } +std::string ImportFunction(const std::string &functiondef_proto, + const std::string &functiondef_library_proto, + const std::string &pass_pipeline, + TF_Status *status) { + FunctionDef functiondef; + auto s = tensorflow::LoadProtoFromBuffer(functiondef_proto, &functiondef); + if (!s.ok()) { + Set_TF_Status_from_Status(status, s); + return "// error"; } - return MlirModuleToString(*module.ConsumeValueOrDie()); + + FunctionDefLibrary fdef_lib; + s = tensorflow::LoadProtoFromBuffer(functiondef_library_proto, &fdef_lib); + if (!s.ok()) { + Set_TF_Status_from_Status(status, s); + return "// error"; + } + + FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib); + s = flib_def.AddFunctionDef(functiondef); + if (!s.ok()) { + Set_TF_Status_from_Status(status, s); + return "// error"; + } + + const std::string &function_name = functiondef.signature().name(); + mlir::MLIRContext context; + auto module = ConvertFunctionToMlir(function_name, flib_def, &context); + if (!module.ok()) { + Set_TF_Status_from_Status(status, module.status()); + return "// error"; + } + + return RunPassPipelineOnModule(module->get(), pass_pipeline, status); } std::string ExperimentalConvertSavedModelToMlir( diff --git a/tensorflow/compiler/mlir/python/mlir.h b/tensorflow/compiler/mlir/python/mlir.h index e68ac28124b..6133068a5e8 100644 --- a/tensorflow/compiler/mlir/python/mlir.h +++ b/tensorflow/compiler/mlir/python/mlir.h @@ -25,13 +25,23 @@ limitations under the License. namespace tensorflow { // Simple wrapper to support tf.mlir.experimental.convert_graph_def. -// Load a .pbptx, convert to MLIR, and (optionally) optimize the module before -// returning it as a string. +// Load a GraphDef (binary or textual proto format), convert to MLIR, and +// (optionally) optimize the module before returning it as a string. // This is an early experimental API, ideally we should return a wrapper object // around a Python binding to the MLIR module. std::string ImportGraphDef(const std::string &proto, const std::string &pass_pipeline, TF_Status *status); +// Simple wrapper to support tf.mlir.experimental.convert_function. +// Load FunctionDef and FunctionDefLibrary (binary or textual proto format), +// convert to MLIR, and (optionally) optimize the module before returning it as +// a string. +// This is an early experimental API, ideally we should return a wrapper object +// around a Python binding to the MLIR module. +std::string ImportFunction(const std::string &functiondef_proto, + const std::string &functiondef_library_proto, + const std::string &pass_pipeline, TF_Status *status); + // Load a SavedModel and return a textual MLIR string corresponding to it. // // Args: diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index b78f3112bdb..153c537589c 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -3654,6 +3654,8 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir( tensorflow::GraphDebugInfo dummy_debug_info; tensorflow::GraphImportConfig specs; specs.graph_as_function = true; + for (const auto* control_ret_node : fbody->control_ret_nodes) + specs.control_outputs.push_back(control_ret_node->name()); return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info, flib_def, specs, name); } diff --git a/tensorflow/python/compiler/mlir/BUILD b/tensorflow/python/compiler/mlir/BUILD index fe59213837b..7e193795e60 100644 --- a/tensorflow/python/compiler/mlir/BUILD +++ b/tensorflow/python/compiler/mlir/BUILD @@ -11,7 +11,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:pywrap_mlir", - "//tensorflow/python:util", + "//tensorflow/python:tf_export", ], ) @@ -22,6 +22,10 @@ py_test( deps = [ ":mlir", "//tensorflow/python:client_testlib", - "//tensorflow/python:platform", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:logging_ops", + "//tensorflow/python:tensor_spec", + "//tensorflow/python/eager:def_function", ], ) diff --git a/tensorflow/python/compiler/mlir/mlir.py b/tensorflow/python/compiler/mlir/mlir.py index fd9918d19f8..3b72abc2850 100644 --- a/tensorflow/python/compiler/mlir/mlir.py +++ b/tensorflow/python/compiler/mlir/mlir.py @@ -26,6 +26,9 @@ from tensorflow.python.util.tf_export import tf_export def convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline'): """Import a GraphDef and convert it to a textual MLIR module. + This API is only intended for inspecting the internals of TensorFlow and the + string returned is at the moment intended for debugging purposes. + Args: graph_def: An object of type graph_pb2.GraphDef or a textual proto representation of a valid GraphDef. @@ -35,7 +38,51 @@ def convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline'): Returns: A textual representation of the MLIR module corresponding to the graphdef. - Raises a RuntimeError on error. + + Raises: + InvalidArgumentError: if graph_def is invalid or cannot be converted to + MLIR. """ return pywrap_mlir.import_graphdef(graph_def, pass_pipeline) + + +@tf_export('mlir.experimental.convert_function') +def convert_function(concrete_function, pass_pipeline='tf-standard-pipeline'): + """Import a ConcreteFunction and convert it to a textual MLIR module. + + This API is only intended for inspecting the internals of TensorFlow and the + string returned is at the moment intended for debugging purposes. + + A [tf.function](https://www.tensorflow.org/api_docs/python/tf/function) can be + imported and converted from TensorFlow to TensorFlow MLIR with this API by + extracting its ConcreteFunction (eagerly-executing wrapper around a + [tf.Graph](https://www.tensorflow.org/api_docs/python/tf/Graph)). + + For example: + >>> @tf.function + ... def add(a, b): + ... return a + b + + >>> concrete_function = add.get_concrete_function( + ... tf.TensorSpec(None, tf.dtypes.float32), + ... tf.TensorSpec(None, tf.dtypes.float32)) + >>> tf.mlir.experimental.convert_function(concrete_function) + '...module attributes {...} {...}' + + Args: + concrete_function: An object of type ConcreteFunction. + pass_pipeline: A textual description of an MLIR Pass Pipeline to run on the + module, see MLIR documentation for the + [textual pass pipeline syntax](https://mlir.llvm.org/docs/PassManagement/#textual-pass-pipeline-specification). + + Returns: + A textual representation of the MLIR module corresponding to the + ConcreteFunction. + + Raises: + InvalidArgumentError: if concrete_function is invalid or cannot be converted + to MLIR. + + """ + return pywrap_mlir.import_function(concrete_function, pass_pipeline) diff --git a/tensorflow/python/compiler/mlir/mlir_test.py b/tensorflow/python/compiler/mlir/mlir_test.py index 2a2362d9f6b..9cb0063dc64 100644 --- a/tensorflow/python/compiler/mlir/mlir_test.py +++ b/tensorflow/python/compiler/mlir/mlir_test.py @@ -19,23 +19,68 @@ from __future__ import division from __future__ import print_function from tensorflow.python.compiler.mlir import mlir +from tensorflow.python.eager import def_function +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import tensor_spec +from tensorflow.python.ops import logging_ops from tensorflow.python.platform import test -class MLIRImportTest(test.TestCase): +class MLIRGraphDefImportTest(test.TestCase): - def test_import_graph_def(self): + def testImport(self): """Tests the basic flow of `tf.mlir.experimental.convert_graph_def`.""" mlir_module = mlir.convert_graph_def('') # An empty graph should contain at least an empty main function. self.assertIn('func @main', mlir_module) - def test_invalid_pbtxt(self): + def testInvalidPbtxt(self): with self.assertRaisesRegex(errors.InvalidArgumentError, 'Could not parse input proto'): mlir.convert_graph_def('some invalid proto') +class MLIRConcreteFunctionImportTest(test.TestCase): + + def testImport(self): + + @def_function.function + def identity(i): + return i + + concrete_function = identity.get_concrete_function( + tensor_spec.TensorSpec(None, dtypes.float32)) + mlir_module = mlir.convert_function(concrete_function) + self.assertRegex(mlir_module, r'func @.*identity.*\(') + + def testImportWithCall(self): + + @def_function.function + def callee(i): + return i + + @def_function.function + def caller(i): + return callee(i) + + concrete_function = caller.get_concrete_function( + tensor_spec.TensorSpec(None, dtypes.float32)) + mlir_module = mlir.convert_function(concrete_function) + self.assertRegex(mlir_module, r'func @.*caller.*\(') + self.assertRegex(mlir_module, r'func @.*callee.*\(') + + def testImportWithControlRet(self): + + @def_function.function + def logging(): + logging_ops.print_v2('some message') + + concrete_function = logging.get_concrete_function() + mlir_module = mlir.convert_function(concrete_function, pass_pipeline='') + self.assertRegex(mlir_module, r'tf\.PrintV2') + self.assertRegex(mlir_module, r'tf_executor.fetch.*: !tf_executor.control') + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/mlir_wrapper.cc b/tensorflow/python/mlir_wrapper.cc index 6bc0183fdc5..fa16e5872ee 100644 --- a/tensorflow/python/mlir_wrapper.cc +++ b/tensorflow/python/mlir_wrapper.cc @@ -31,6 +31,17 @@ PYBIND11_MODULE(_pywrap_mlir, m) { return output; }); + m.def("ImportFunction", [](const std::string &functiondef, + const std::string &functiondef_library, + const std::string &pass_pipeline) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + std::string output = tensorflow::ImportFunction( + functiondef, functiondef_library, pass_pipeline, status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return output; + }); + m.def("ExperimentalConvertSavedModelToMlir", [](const std::string &saved_model_path, const std::string &exported_names, bool show_debug_info) { diff --git a/tensorflow/python/pywrap_mlir.py b/tensorflow/python/pywrap_mlir.py index a8a8181ce48..82048140e16 100644 --- a/tensorflow/python/pywrap_mlir.py +++ b/tensorflow/python/pywrap_mlir.py @@ -29,6 +29,13 @@ def import_graphdef(graphdef, pass_pipeline): pass_pipeline.encode('utf-8')) +def import_function(concrete_function, pass_pipeline): + return ImportFunction( + str(concrete_function.function_def).encode('utf-8'), + str(concrete_function.graph.as_graph_def().library).encode('utf-8'), + pass_pipeline.encode('utf-8')) + + def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names, show_debug_info): return ExperimentalConvertSavedModelToMlir( diff --git a/tensorflow/tools/api/golden/v1/tensorflow.mlir.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.mlir.experimental.pbtxt index e268fcf8e73..7a140a13bc6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.mlir.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.mlir.experimental.pbtxt @@ -1,5 +1,9 @@ path: "tensorflow.mlir.experimental" tf_module { + member_method { + name: "convert_function" + argspec: "args=[\'concrete_function\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], " + } member_method { name: "convert_graph_def" argspec: "args=[\'graph_def\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.mlir.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.mlir.experimental.pbtxt index e268fcf8e73..7a140a13bc6 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.mlir.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.mlir.experimental.pbtxt @@ -1,5 +1,9 @@ path: "tensorflow.mlir.experimental" tf_module { + member_method { + name: "convert_function" + argspec: "args=[\'concrete_function\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], " + } member_method { name: "convert_graph_def" argspec: "args=[\'graph_def\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "