Add show_debug_info flag to MLIR testing functions

mlir.convert_* are an experimental testing interface that returns a textual version of the input function/graph. Enable dumping location information as well.

PiperOrigin-RevId: 350450836
Change-Id: Ifc9d3377b7ceee186cd09b010ce4fd4371607e81
This commit is contained in:
Jacques Pienaar 2021-01-06 16:32:45 -08:00 committed by TensorFlower Gardener
parent a30df1baa1
commit 4de058963e
9 changed files with 54 additions and 30 deletions

View File

@ -103,12 +103,17 @@
value of `is_dynamic_op` is not True. We didn't use the value for
`max_batch_size` for building TensorRT engines.
* Issue a warning when function get_tensorrt_rewriter_config is used.
* Other:
* TF XLA
* Add new enum value `MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED` to
`tf.config.experimental.mlir_bridge_rollout` to enable a \"safe\" mode.
This runs the MLIR bridge only when
an analysis of the graph determines that it is safe to run.
* Other
* Adding show_debug_info to mlir.convert_graph_def and
mlir.convert_function.
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/python/mlir.h"
#include <string>
#include "llvm/Support/raw_ostream.h"
@ -44,7 +46,7 @@ namespace {
// empty.
std::string RunPassPipelineOnModule(mlir::ModuleOp module,
const std::string &pass_pipeline,
TF_Status *status) {
bool show_debug_info, TF_Status *status) {
if (!pass_pipeline.empty()) {
mlir::PassManager pm(module.getContext());
std::string error;
@ -61,14 +63,14 @@ std::string RunPassPipelineOnModule(mlir::ModuleOp module,
return "// error";
}
}
return MlirModuleToString(module);
return MlirModuleToString(module, show_debug_info);
}
} // anonymous namespace
std::string ImportGraphDef(const std::string &proto,
const std::string &pass_pipeline,
TF_Status *status) {
bool show_debug_info, TF_Status *status) {
GraphDef graphdef;
auto s = tensorflow::LoadProtoFromBuffer(proto, &graphdef);
if (!s.ok()) {
@ -84,13 +86,14 @@ std::string ImportGraphDef(const std::string &proto,
return "// error";
}
return RunPassPipelineOnModule(module->get(), pass_pipeline, status);
return RunPassPipelineOnModule(module->get(), pass_pipeline, show_debug_info,
status);
}
std::string ImportFunction(const std::string &functiondef_proto,
const std::string &functiondef_library_proto,
const std::string &pass_pipeline,
TF_Status *status) {
bool show_debug_info, TF_Status *status) {
FunctionDef functiondef;
auto s = tensorflow::LoadProtoFromBuffer(functiondef_proto, &functiondef);
if (!s.ok()) {
@ -105,15 +108,16 @@ std::string ImportFunction(const std::string &functiondef_proto,
return "// error";
}
const std::string &function_name = functiondef.signature().name();
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
s = flib_def.AddFunctionDef(functiondef);
s = flib_def.AddFunctionDef(functiondef,
flib_def.GetStackTraces(function_name));
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return "// error";
}
const std::string &function_name = functiondef.signature().name();
const tensorflow::FunctionDef *fdef = flib_def.Find(function_name);
if (fdef == nullptr) {
s = tensorflow::errors::NotFound("Cannot find function ", function_name);
@ -136,7 +140,8 @@ std::string ImportFunction(const std::string &functiondef_proto,
return "// error";
}
return RunPassPipelineOnModule(module->get(), pass_pipeline, status);
return RunPassPipelineOnModule(module->get(), pass_pipeline, show_debug_info,
status);
}
std::string ExperimentalConvertSavedModelToMlir(

View File

@ -30,7 +30,8 @@ namespace tensorflow {
// This is an early experimental API, ideally we should return a wrapper object
// around a Python binding to the MLIR module.
std::string ImportGraphDef(const std::string &proto,
const std::string &pass_pipeline, TF_Status *status);
const std::string &pass_pipeline,
bool show_debug_info, TF_Status *status);
// Simple wrapper to support tf.mlir.experimental.convert_function.
// Load FunctionDef and FunctionDefLibrary (binary or textual proto format),
@ -40,7 +41,8 @@ std::string ImportGraphDef(const std::string &proto,
// around a Python binding to the MLIR module.
std::string ImportFunction(const std::string &functiondef_proto,
const std::string &functiondef_library_proto,
const std::string &pass_pipeline, TF_Status *status);
const std::string &pass_pipeline,
bool show_debug_info, TF_Status *status);
// Load a SavedModel and return a textual MLIR string corresponding to it.
//

View File

@ -23,7 +23,9 @@ from tensorflow.python.util.tf_export import tf_export
@tf_export('mlir.experimental.convert_graph_def')
def convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline'):
def convert_graph_def(graph_def,
pass_pipeline='tf-standard-pipeline',
show_debug_info=False):
"""Import a GraphDef and convert it to a textual MLIR module.
This API is only intended for inspecting the internals of TensorFlow and the
@ -35,6 +37,7 @@ def convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline'):
pass_pipeline: A textual description of an MLIR Pass Pipeline to run on the
module, see MLIR documentation for the
[textual pass pipeline syntax](https://mlir.llvm.org/docs/PassManagement/#textual-pass-pipeline-specification).
show_debug_info: Whether to include locations in the emitted textual form.
Returns:
A textual representation of the MLIR module corresponding to the graphdef.
@ -44,11 +47,13 @@ def convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline'):
MLIR.
"""
return pywrap_mlir.import_graphdef(graph_def, pass_pipeline)
return pywrap_mlir.import_graphdef(graph_def, pass_pipeline, show_debug_info)
@tf_export('mlir.experimental.convert_function')
def convert_function(concrete_function, pass_pipeline='tf-standard-pipeline'):
def convert_function(concrete_function,
pass_pipeline='tf-standard-pipeline',
show_debug_info=False):
"""Import a ConcreteFunction and convert it to a textual MLIR module.
This API is only intended for inspecting the internals of TensorFlow and the
@ -75,6 +80,7 @@ def convert_function(concrete_function, pass_pipeline='tf-standard-pipeline'):
pass_pipeline: A textual description of an MLIR Pass Pipeline to run on the
module, see MLIR documentation for the
[textual pass pipeline syntax](https://mlir.llvm.org/docs/PassManagement/#textual-pass-pipeline-specification).
show_debug_info: Whether to include locations in the emitted textual form.
Returns:
A textual representation of the MLIR module corresponding to the
@ -85,4 +91,5 @@ def convert_function(concrete_function, pass_pipeline='tf-standard-pipeline'):
to MLIR.
"""
return pywrap_mlir.import_function(concrete_function, pass_pipeline)
return pywrap_mlir.import_function(concrete_function, pass_pipeline,
show_debug_info)

View File

@ -51,8 +51,9 @@ class MLIRConcreteFunctionImportTest(test.TestCase):
concrete_function = sqr.get_concrete_function(
tensor_spec.TensorSpec(None, dtypes.float32))
mlir_module = mlir.convert_function(concrete_function)
mlir_module = mlir.convert_function(concrete_function, show_debug_info=True)
self.assertRegex(mlir_module, r'func @.*sqr.*\(')
self.assertRegex(mlir_module, r'loc\(')
def testImportWithCall(self):

View File

@ -22,22 +22,25 @@ limitations under the License.
PYBIND11_MODULE(_pywrap_mlir, m) {
m.def("ImportGraphDef",
[](const std::string &graphdef, const std::string &pass_pipeline) {
[](const std::string &graphdef, const std::string &pass_pipeline,
bool show_debug_info) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
std::string output =
tensorflow::ImportGraphDef(graphdef, pass_pipeline, status.get());
std::string output = tensorflow::ImportGraphDef(
graphdef, pass_pipeline, show_debug_info, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
return output;
});
m.def("ImportFunction", [](const std::string &functiondef,
const std::string &functiondef_library,
const std::string &pass_pipeline) {
const std::string &pass_pipeline,
bool show_debug_info) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
std::string output = tensorflow::ImportFunction(
functiondef, functiondef_library, pass_pipeline, status.get());
functiondef, functiondef_library, pass_pipeline, show_debug_info,
status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
return output;
});

View File

@ -23,16 +23,17 @@ from tensorflow.python import pywrap_tensorflow
from tensorflow.python._pywrap_mlir import *
def import_graphdef(graphdef, pass_pipeline):
def import_graphdef(graphdef, pass_pipeline, show_debug_info):
return ImportGraphDef(
str(graphdef).encode('utf-8'), pass_pipeline.encode('utf-8'))
str(graphdef).encode('utf-8'), pass_pipeline.encode('utf-8'),
show_debug_info)
def import_function(concrete_function, pass_pipeline):
def import_function(concrete_function, pass_pipeline, show_debug_info):
return ImportFunction(
str(concrete_function.function_def).encode('utf-8'),
str(concrete_function.graph.as_graph_def().library).encode('utf-8'),
pass_pipeline.encode('utf-8'))
pass_pipeline.encode('utf-8'), show_debug_info)
def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names,

View File

@ -2,10 +2,10 @@ path: "tensorflow.mlir.experimental"
tf_module {
member_method {
name: "convert_function"
argspec: "args=[\'concrete_function\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
argspec: "args=[\'concrete_function\', \'pass_pipeline\', \'show_debug_info\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\', \'False\'], "
}
member_method {
name: "convert_graph_def"
argspec: "args=[\'graph_def\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
argspec: "args=[\'graph_def\', \'pass_pipeline\', \'show_debug_info\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\', \'False\'], "
}
}

View File

@ -2,10 +2,10 @@ path: "tensorflow.mlir.experimental"
tf_module {
member_method {
name: "convert_function"
argspec: "args=[\'concrete_function\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
argspec: "args=[\'concrete_function\', \'pass_pipeline\', \'show_debug_info\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\', \'False\'], "
}
member_method {
name: "convert_graph_def"
argspec: "args=[\'graph_def\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
argspec: "args=[\'graph_def\', \'pass_pipeline\', \'show_debug_info\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\', \'False\'], "
}
}