Add show_debug_info flag to MLIR testing functions
mlir.convert_* are experimental, testing interface that returns a textual version of input function/graph. Enable dumping location information as well. PiperOrigin-RevId: 350450836 Change-Id: Ifc9d3377b7ceee186cd09b010ce4fd4371607e81
This commit is contained in:
parent
a30df1baa1
commit
4de058963e
@ -103,12 +103,17 @@
|
||||
value of `is_dynamic_op` is not True. We didn't use the value for
|
||||
`max_batch_size` for building TensorRT engines.
|
||||
* Issue a warning when function get_tensorrt_rewriter_config is used.
|
||||
* Other:
|
||||
|
||||
* TF XLA
|
||||
* Add new enum value `MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED` to
|
||||
`tf.config.experimental.mlir_bridge_rollout` to enable a \"safe\" mode.
|
||||
This runs the MLIR bridge only when an analysis of the graph only when
|
||||
an analysis of the graph determines that it is safe to run.
|
||||
|
||||
* Other
|
||||
* Adding show_debug_info to mlir.convert_graph_def and
|
||||
mlir.convert_function.
|
||||
|
||||
## Thanks to our Contributors
|
||||
|
||||
This release contains contributions from many people at Google, as well as:
|
||||
|
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/python/mlir.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
@ -44,7 +46,7 @@ namespace {
|
||||
// empty.
|
||||
std::string RunPassPipelineOnModule(mlir::ModuleOp module,
|
||||
const std::string &pass_pipeline,
|
||||
TF_Status *status) {
|
||||
bool show_debug_info, TF_Status *status) {
|
||||
if (!pass_pipeline.empty()) {
|
||||
mlir::PassManager pm(module.getContext());
|
||||
std::string error;
|
||||
@ -61,14 +63,14 @@ std::string RunPassPipelineOnModule(mlir::ModuleOp module,
|
||||
return "// error";
|
||||
}
|
||||
}
|
||||
return MlirModuleToString(module);
|
||||
return MlirModuleToString(module, show_debug_info);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
std::string ImportGraphDef(const std::string &proto,
|
||||
const std::string &pass_pipeline,
|
||||
TF_Status *status) {
|
||||
bool show_debug_info, TF_Status *status) {
|
||||
GraphDef graphdef;
|
||||
auto s = tensorflow::LoadProtoFromBuffer(proto, &graphdef);
|
||||
if (!s.ok()) {
|
||||
@ -84,13 +86,14 @@ std::string ImportGraphDef(const std::string &proto,
|
||||
return "// error";
|
||||
}
|
||||
|
||||
return RunPassPipelineOnModule(module->get(), pass_pipeline, status);
|
||||
return RunPassPipelineOnModule(module->get(), pass_pipeline, show_debug_info,
|
||||
status);
|
||||
}
|
||||
|
||||
std::string ImportFunction(const std::string &functiondef_proto,
|
||||
const std::string &functiondef_library_proto,
|
||||
const std::string &pass_pipeline,
|
||||
TF_Status *status) {
|
||||
bool show_debug_info, TF_Status *status) {
|
||||
FunctionDef functiondef;
|
||||
auto s = tensorflow::LoadProtoFromBuffer(functiondef_proto, &functiondef);
|
||||
if (!s.ok()) {
|
||||
@ -105,15 +108,16 @@ std::string ImportFunction(const std::string &functiondef_proto,
|
||||
return "// error";
|
||||
}
|
||||
|
||||
const std::string &function_name = functiondef.signature().name();
|
||||
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
|
||||
s = flib_def.AddFunctionDef(functiondef);
|
||||
s = flib_def.AddFunctionDef(functiondef,
|
||||
flib_def.GetStackTraces(function_name));
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
return "// error";
|
||||
}
|
||||
|
||||
const std::string &function_name = functiondef.signature().name();
|
||||
|
||||
const tensorflow::FunctionDef *fdef = flib_def.Find(function_name);
|
||||
if (fdef == nullptr) {
|
||||
s = tensorflow::errors::NotFound("Cannot find function ", function_name);
|
||||
@ -136,7 +140,8 @@ std::string ImportFunction(const std::string &functiondef_proto,
|
||||
return "// error";
|
||||
}
|
||||
|
||||
return RunPassPipelineOnModule(module->get(), pass_pipeline, status);
|
||||
return RunPassPipelineOnModule(module->get(), pass_pipeline, show_debug_info,
|
||||
status);
|
||||
}
|
||||
|
||||
std::string ExperimentalConvertSavedModelToMlir(
|
||||
|
@ -30,7 +30,8 @@ namespace tensorflow {
|
||||
// This is an early experimental API, ideally we should return a wrapper object
|
||||
// around a Python binding to the MLIR module.
|
||||
std::string ImportGraphDef(const std::string &proto,
|
||||
const std::string &pass_pipeline, TF_Status *status);
|
||||
const std::string &pass_pipeline,
|
||||
bool show_debug_info, TF_Status *status);
|
||||
|
||||
// Simple wrapper to support tf.mlir.experimental.convert_function.
|
||||
// Load FunctionDef and FunctionDefLibrary (binary or textual proto format),
|
||||
@ -40,7 +41,8 @@ std::string ImportGraphDef(const std::string &proto,
|
||||
// around a Python binding to the MLIR module.
|
||||
std::string ImportFunction(const std::string &functiondef_proto,
|
||||
const std::string &functiondef_library_proto,
|
||||
const std::string &pass_pipeline, TF_Status *status);
|
||||
const std::string &pass_pipeline,
|
||||
bool show_debug_info, TF_Status *status);
|
||||
|
||||
// Load a SavedModel and return a textual MLIR string corresponding to it.
|
||||
//
|
||||
|
@ -23,7 +23,9 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@tf_export('mlir.experimental.convert_graph_def')
|
||||
def convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline'):
|
||||
def convert_graph_def(graph_def,
|
||||
pass_pipeline='tf-standard-pipeline',
|
||||
show_debug_info=False):
|
||||
"""Import a GraphDef and convert it to a textual MLIR module.
|
||||
|
||||
This API is only intended for inspecting the internals of TensorFlow and the
|
||||
@ -35,6 +37,7 @@ def convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline'):
|
||||
pass_pipeline: A textual description of an MLIR Pass Pipeline to run on the
|
||||
module, see MLIR documentation for the
|
||||
[textual pass pipeline syntax](https://mlir.llvm.org/docs/PassManagement/#textual-pass-pipeline-specification).
|
||||
show_debug_info: Whether to include locations in the emitted textual form.
|
||||
|
||||
Returns:
|
||||
A textual representation of the MLIR module corresponding to the graphdef.
|
||||
@ -44,11 +47,13 @@ def convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline'):
|
||||
MLIR.
|
||||
|
||||
"""
|
||||
return pywrap_mlir.import_graphdef(graph_def, pass_pipeline)
|
||||
return pywrap_mlir.import_graphdef(graph_def, pass_pipeline, show_debug_info)
|
||||
|
||||
|
||||
@tf_export('mlir.experimental.convert_function')
|
||||
def convert_function(concrete_function, pass_pipeline='tf-standard-pipeline'):
|
||||
def convert_function(concrete_function,
|
||||
pass_pipeline='tf-standard-pipeline',
|
||||
show_debug_info=False):
|
||||
"""Import a ConcreteFunction and convert it to a textual MLIR module.
|
||||
|
||||
This API is only intended for inspecting the internals of TensorFlow and the
|
||||
@ -75,6 +80,7 @@ def convert_function(concrete_function, pass_pipeline='tf-standard-pipeline'):
|
||||
pass_pipeline: A textual description of an MLIR Pass Pipeline to run on the
|
||||
module, see MLIR documentation for the
|
||||
[textual pass pipeline syntax](https://mlir.llvm.org/docs/PassManagement/#textual-pass-pipeline-specification).
|
||||
show_debug_info: Whether to include locations in the emitted textual form.
|
||||
|
||||
Returns:
|
||||
A textual representation of the MLIR module corresponding to the
|
||||
@ -85,4 +91,5 @@ def convert_function(concrete_function, pass_pipeline='tf-standard-pipeline'):
|
||||
to MLIR.
|
||||
|
||||
"""
|
||||
return pywrap_mlir.import_function(concrete_function, pass_pipeline)
|
||||
return pywrap_mlir.import_function(concrete_function, pass_pipeline,
|
||||
show_debug_info)
|
||||
|
@ -51,8 +51,9 @@ class MLIRConcreteFunctionImportTest(test.TestCase):
|
||||
|
||||
concrete_function = sqr.get_concrete_function(
|
||||
tensor_spec.TensorSpec(None, dtypes.float32))
|
||||
mlir_module = mlir.convert_function(concrete_function)
|
||||
mlir_module = mlir.convert_function(concrete_function, show_debug_info=True)
|
||||
self.assertRegex(mlir_module, r'func @.*sqr.*\(')
|
||||
self.assertRegex(mlir_module, r'loc\(')
|
||||
|
||||
def testImportWithCall(self):
|
||||
|
||||
|
@ -22,22 +22,25 @@ limitations under the License.
|
||||
|
||||
PYBIND11_MODULE(_pywrap_mlir, m) {
|
||||
m.def("ImportGraphDef",
|
||||
[](const std::string &graphdef, const std::string &pass_pipeline) {
|
||||
[](const std::string &graphdef, const std::string &pass_pipeline,
|
||||
bool show_debug_info) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
std::string output =
|
||||
tensorflow::ImportGraphDef(graphdef, pass_pipeline, status.get());
|
||||
std::string output = tensorflow::ImportGraphDef(
|
||||
graphdef, pass_pipeline, show_debug_info, status.get());
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
return output;
|
||||
});
|
||||
|
||||
m.def("ImportFunction", [](const std::string &functiondef,
|
||||
const std::string &functiondef_library,
|
||||
const std::string &pass_pipeline) {
|
||||
const std::string &pass_pipeline,
|
||||
bool show_debug_info) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
std::string output = tensorflow::ImportFunction(
|
||||
functiondef, functiondef_library, pass_pipeline, status.get());
|
||||
functiondef, functiondef_library, pass_pipeline, show_debug_info,
|
||||
status.get());
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
return output;
|
||||
});
|
||||
|
@ -23,16 +23,17 @@ from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python._pywrap_mlir import *
|
||||
|
||||
|
||||
def import_graphdef(graphdef, pass_pipeline):
|
||||
def import_graphdef(graphdef, pass_pipeline, show_debug_info):
|
||||
return ImportGraphDef(
|
||||
str(graphdef).encode('utf-8'), pass_pipeline.encode('utf-8'))
|
||||
str(graphdef).encode('utf-8'), pass_pipeline.encode('utf-8'),
|
||||
show_debug_info)
|
||||
|
||||
|
||||
def import_function(concrete_function, pass_pipeline):
|
||||
def import_function(concrete_function, pass_pipeline, show_debug_info):
|
||||
return ImportFunction(
|
||||
str(concrete_function.function_def).encode('utf-8'),
|
||||
str(concrete_function.graph.as_graph_def().library).encode('utf-8'),
|
||||
pass_pipeline.encode('utf-8'))
|
||||
pass_pipeline.encode('utf-8'), show_debug_info)
|
||||
|
||||
|
||||
def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names,
|
||||
|
@ -2,10 +2,10 @@ path: "tensorflow.mlir.experimental"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "convert_function"
|
||||
argspec: "args=[\'concrete_function\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
|
||||
argspec: "args=[\'concrete_function\', \'pass_pipeline\', \'show_debug_info\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\', \'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "convert_graph_def"
|
||||
argspec: "args=[\'graph_def\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
|
||||
argspec: "args=[\'graph_def\', \'pass_pipeline\', \'show_debug_info\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\', \'False\'], "
|
||||
}
|
||||
}
|
||||
|
@ -2,10 +2,10 @@ path: "tensorflow.mlir.experimental"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "convert_function"
|
||||
argspec: "args=[\'concrete_function\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
|
||||
argspec: "args=[\'concrete_function\', \'pass_pipeline\', \'show_debug_info\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\', \'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "convert_graph_def"
|
||||
argspec: "args=[\'graph_def\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
|
||||
argspec: "args=[\'graph_def\', \'pass_pipeline\', \'show_debug_info\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\', \'False\'], "
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user