Use flib of attached context.

Stacks not part of proto. Moved to TF2 only test and run with TF2_BEHAVIOR env set.

PiperOrigin-RevId: 351590975
Change-Id: If5aa883c2890e53f7feda54e7ccf05d77921cfa3
This commit is contained in:
Jacques Pienaar 2021-01-13 08:22:25 -08:00 committed by TensorFlower Gardener
parent 61bf442edc
commit 2ab8822125
6 changed files with 39 additions and 39 deletions

View File

@ -11,6 +11,8 @@ cc_library(
srcs = ["mlir.cc"],
hdrs = ["mlir.h"],
deps = [
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:tfe_context_internal",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_helper",
"//tensorflow/compiler/mlir/tensorflow",
@ -41,6 +43,7 @@ cc_library(
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime:core_cpu_base_no_ops",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",

View File

@ -23,6 +23,8 @@ limitations under the License.
#include "mlir/Parser.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
@ -32,6 +34,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/common_runtime/function_def_utils.h"
#include "tensorflow/core/framework/function.h"
@ -91,9 +94,9 @@ std::string ImportGraphDef(const std::string &proto,
}
std::string ImportFunction(const std::string &functiondef_proto,
const std::string &functiondef_library_proto,
const std::string &pass_pipeline,
bool show_debug_info, TF_Status *status) {
bool show_debug_info, TFE_Context *tfe_context,
TF_Status *status) {
FunctionDef functiondef;
auto s = tensorflow::LoadProtoFromBuffer(functiondef_proto, &functiondef);
if (!s.ok()) {
@ -101,23 +104,9 @@ std::string ImportFunction(const std::string &functiondef_proto,
return "// error";
}
FunctionDefLibrary fdef_lib;
s = tensorflow::LoadProtoFromBuffer(functiondef_library_proto, &fdef_lib);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return "// error";
}
const std::string &function_name = functiondef.signature().name();
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
s = flib_def.AddFunctionDef(functiondef,
flib_def.GetStackTraces(function_name));
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return "// error";
}
EagerContext *cpp_context = ContextFromInterface(unwrap(tfe_context));
FunctionLibraryDefinition &flib_def = *cpp_context->FuncLibDef();
const tensorflow::FunctionDef *fdef = flib_def.Find(function_name);
if (fdef == nullptr) {
s = tensorflow::errors::NotFound("Cannot find function ", function_name);

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_status.h"
namespace tensorflow {
@ -34,15 +35,14 @@ std::string ImportGraphDef(const std::string &proto,
bool show_debug_info, TF_Status *status);
// Simple wrapper to support tf.mlir.experimental.convert_function.
// Load FunctionDef and FunctionDefLibrary (binary or textual proto format),
// convert to MLIR, and (optionally) optimize the module before returning it as
// a string.
// Load FunctionDef (binary or textual proto format), convert to MLIR, and
// (optionally) optimize the module before returning it as a string.
// This is an early experimental API, ideally we should return a wrapper object
// around a Python binding to the MLIR module.
std::string ImportFunction(const std::string &functiondef_proto,
const std::string &functiondef_library_proto,
const std::string &pass_pipeline,
bool show_debug_info, TF_Status *status);
bool show_debug_info, TFE_Context *context,
TF_Status *status);
// Load a SavedModel and return a textual MLIR string corresponding to it.
//

View File

@ -23,6 +23,7 @@ from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import logging_ops
from tensorflow.python.platform import test
@ -43,6 +44,7 @@ class MLIRGraphDefImportTest(test.TestCase):
class MLIRConcreteFunctionImportTest(test.TestCase):
@test_util.run_v2_only
def testImport(self):
@def_function.function
@ -55,6 +57,7 @@ class MLIRConcreteFunctionImportTest(test.TestCase):
self.assertRegex(mlir_module, r'func @.*sqr.*\(')
self.assertRegex(mlir_module, r'loc\(')
@test_util.run_v2_only
def testImportWithCall(self):
@def_function.function
@ -71,6 +74,7 @@ class MLIRConcreteFunctionImportTest(test.TestCase):
self.assertRegex(mlir_module, r'func @.*caller.*\(')
self.assertRegex(mlir_module, r'func private @.*callee.*\(')
@test_util.run_v2_only
def testImportWithControlRet(self):
@def_function.function

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/compiler/mlir/python/mlir.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
@ -32,18 +33,19 @@ PYBIND11_MODULE(_pywrap_mlir, m) {
return output;
});
m.def("ImportFunction", [](const std::string &functiondef,
const std::string &functiondef_library,
const std::string &pass_pipeline,
bool show_debug_info) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
std::string output = tensorflow::ImportFunction(
functiondef, functiondef_library, pass_pipeline, show_debug_info,
status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
return output;
});
m.def("ImportFunction",
[](const py::handle &context, const std::string &functiondef,
const std::string &pass_pipeline, bool show_debug_info) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
auto *ctxt = static_cast<TFE_Context *>(
PyCapsule_GetPointer(context.ptr(), nullptr));
if (!ctxt) throw py::error_already_set();
std::string output = tensorflow::ImportFunction(
functiondef, pass_pipeline, show_debug_info, ctxt, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
return output;
});
m.def("ExperimentalConvertSavedModelToMlir",
[](const std::string &saved_model_path,

View File

@ -20,6 +20,7 @@ from __future__ import print_function
# pylint: disable=invalid-import-order, g-bad-import-order, wildcard-import, unused-import, undefined-variable
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python._pywrap_mlir import *
@ -30,10 +31,11 @@ def import_graphdef(graphdef, pass_pipeline, show_debug_info):
def import_function(concrete_function, pass_pipeline, show_debug_info):
return ImportFunction(
str(concrete_function.function_def).encode('utf-8'),
str(concrete_function.graph.as_graph_def().library).encode('utf-8'),
pass_pipeline.encode('utf-8'), show_debug_info)
ctxt = context.context()
ctxt.ensure_initialized()
return ImportFunction(ctxt._handle,
str(concrete_function.function_def).encode('utf-8'),
pass_pipeline.encode('utf-8'), show_debug_info)
def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names,