Use flib of attached context.
Stacks not part of proto. Moved to TF2 only test and run with TF2_BEHAVIOR env set. PiperOrigin-RevId: 351590975 Change-Id: If5aa883c2890e53f7feda54e7ccf05d77921cfa3
This commit is contained in:
parent
61bf442edc
commit
2ab8822125
@ -11,6 +11,8 @@ cc_library(
|
||||
srcs = ["mlir.cc"],
|
||||
hdrs = ["mlir.h"],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:tfe_context_internal",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
@ -41,6 +43,7 @@ cc_library(
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/common_runtime:core_cpu_base_no_ops",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
||||
@ -23,6 +23,8 @@ limitations under the License.
|
||||
#include "mlir/Parser.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
|
||||
@ -32,6 +34,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/function_body.h"
|
||||
#include "tensorflow/core/common_runtime/function_def_utils.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
@ -91,9 +94,9 @@ std::string ImportGraphDef(const std::string &proto,
|
||||
}
|
||||
|
||||
std::string ImportFunction(const std::string &functiondef_proto,
|
||||
const std::string &functiondef_library_proto,
|
||||
const std::string &pass_pipeline,
|
||||
bool show_debug_info, TF_Status *status) {
|
||||
bool show_debug_info, TFE_Context *tfe_context,
|
||||
TF_Status *status) {
|
||||
FunctionDef functiondef;
|
||||
auto s = tensorflow::LoadProtoFromBuffer(functiondef_proto, &functiondef);
|
||||
if (!s.ok()) {
|
||||
@ -101,23 +104,9 @@ std::string ImportFunction(const std::string &functiondef_proto,
|
||||
return "// error";
|
||||
}
|
||||
|
||||
FunctionDefLibrary fdef_lib;
|
||||
s = tensorflow::LoadProtoFromBuffer(functiondef_library_proto, &fdef_lib);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
return "// error";
|
||||
}
|
||||
|
||||
const std::string &function_name = functiondef.signature().name();
|
||||
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
|
||||
s = flib_def.AddFunctionDef(functiondef,
|
||||
flib_def.GetStackTraces(function_name));
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
return "// error";
|
||||
}
|
||||
|
||||
EagerContext *cpp_context = ContextFromInterface(unwrap(tfe_context));
|
||||
FunctionLibraryDefinition &flib_def = *cpp_context->FuncLibDef();
|
||||
const tensorflow::FunctionDef *fdef = flib_def.Find(function_name);
|
||||
if (fdef == nullptr) {
|
||||
s = tensorflow::errors::NotFound("Cannot find function ", function_name);
|
||||
|
||||
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -34,15 +35,14 @@ std::string ImportGraphDef(const std::string &proto,
|
||||
bool show_debug_info, TF_Status *status);
|
||||
|
||||
// Simple wrapper to support tf.mlir.experimental.convert_function.
|
||||
// Load FunctionDef and FunctionDefLibrary (binary or textual proto format),
|
||||
// convert to MLIR, and (optionally) optimize the module before returning it as
|
||||
// a string.
|
||||
// Load FunctionDef (binary or textual proto format), convert to MLIR, and
|
||||
// (optionally) optimize the module before returning it as a string.
|
||||
// This is an early experimental API, ideally we should return a wrapper object
|
||||
// around a Python binding to the MLIR module.
|
||||
std::string ImportFunction(const std::string &functiondef_proto,
|
||||
const std::string &functiondef_library_proto,
|
||||
const std::string &pass_pipeline,
|
||||
bool show_debug_info, TF_Status *status);
|
||||
bool show_debug_info, TFE_Context *context,
|
||||
TF_Status *status);
|
||||
|
||||
// Load a SavedModel and return a textual MLIR string corresponding to it.
|
||||
//
|
||||
|
||||
@ -23,6 +23,7 @@ from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import logging_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -43,6 +44,7 @@ class MLIRGraphDefImportTest(test.TestCase):
|
||||
|
||||
class MLIRConcreteFunctionImportTest(test.TestCase):
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testImport(self):
|
||||
|
||||
@def_function.function
|
||||
@ -55,6 +57,7 @@ class MLIRConcreteFunctionImportTest(test.TestCase):
|
||||
self.assertRegex(mlir_module, r'func @.*sqr.*\(')
|
||||
self.assertRegex(mlir_module, r'loc\(')
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testImportWithCall(self):
|
||||
|
||||
@def_function.function
|
||||
@ -71,6 +74,7 @@ class MLIRConcreteFunctionImportTest(test.TestCase):
|
||||
self.assertRegex(mlir_module, r'func @.*caller.*\(')
|
||||
self.assertRegex(mlir_module, r'func private @.*callee.*\(')
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testImportWithControlRet(self):
|
||||
|
||||
@def_function.function
|
||||
|
||||
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/compiler/mlir/python/mlir.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||
@ -32,18 +33,19 @@ PYBIND11_MODULE(_pywrap_mlir, m) {
|
||||
return output;
|
||||
});
|
||||
|
||||
m.def("ImportFunction", [](const std::string &functiondef,
|
||||
const std::string &functiondef_library,
|
||||
const std::string &pass_pipeline,
|
||||
bool show_debug_info) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
std::string output = tensorflow::ImportFunction(
|
||||
functiondef, functiondef_library, pass_pipeline, show_debug_info,
|
||||
status.get());
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
return output;
|
||||
});
|
||||
m.def("ImportFunction",
|
||||
[](const py::handle &context, const std::string &functiondef,
|
||||
const std::string &pass_pipeline, bool show_debug_info) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
auto *ctxt = static_cast<TFE_Context *>(
|
||||
PyCapsule_GetPointer(context.ptr(), nullptr));
|
||||
if (!ctxt) throw py::error_already_set();
|
||||
std::string output = tensorflow::ImportFunction(
|
||||
functiondef, pass_pipeline, show_debug_info, ctxt, status.get());
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
return output;
|
||||
});
|
||||
|
||||
m.def("ExperimentalConvertSavedModelToMlir",
|
||||
[](const std::string &saved_model_path,
|
||||
|
||||
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
# pylint: disable=invalid-import-order, g-bad-import-order, wildcard-import, unused-import, undefined-variable
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python._pywrap_mlir import *
|
||||
|
||||
|
||||
@ -30,10 +31,11 @@ def import_graphdef(graphdef, pass_pipeline, show_debug_info):
|
||||
|
||||
|
||||
def import_function(concrete_function, pass_pipeline, show_debug_info):
|
||||
return ImportFunction(
|
||||
str(concrete_function.function_def).encode('utf-8'),
|
||||
str(concrete_function.graph.as_graph_def().library).encode('utf-8'),
|
||||
pass_pipeline.encode('utf-8'), show_debug_info)
|
||||
ctxt = context.context()
|
||||
ctxt.ensure_initialized()
|
||||
return ImportFunction(ctxt._handle,
|
||||
str(concrete_function.function_def).encode('utf-8'),
|
||||
pass_pipeline.encode('utf-8'), show_debug_info)
|
||||
|
||||
|
||||
def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user