Add an MLIR tracing implementation to the C unified API

This is plumbing just enough to pass all the unit-tests.
The conversion to the function library is quite inefficient, but it isn't
clear if we want to optimize this or just focus on TFRT moving forward.

PiperOrigin-RevId: 313356850
Change-Id: I83815317d4958786d0103168b5d88498f89511ed
This commit is contained in:
Mehdi Amini 2020-05-27 02:56:50 -07:00 committed by TensorFlower Gardener
parent 7738c1818e
commit f0ef163443
10 changed files with 632 additions and 4 deletions

View File

@ -216,6 +216,7 @@ tf_cuda_library(
],
visibility = [
"//tensorflow/c:__subpackages__",
"//tensorflow/compiler/mlir/tensorflow/c:__subpackages__",
],
deps = select({
"//tensorflow:android": [

View File

@ -144,6 +144,24 @@ cc_library(
],
)
cc_library(
name = "c_api_unified_internal",
hdrs = [
"c_api_unified_experimental_internal.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":c_api",
":c_api_experimental",
"//tensorflow/c:c_api_internal",
"//tensorflow/c:tf_status",
"//tensorflow/core/platform:casts",
"//tensorflow/core/platform:types",
],
)
cc_library(
name = "tensor_handle_interface",
hdrs = ["tensor_handle_interface.h"],
@ -514,6 +532,7 @@ tf_cuda_cc_test(
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",

View File

@ -58,7 +58,7 @@ T* dyncast(S source) {
// GraphContext and vice-versa).
class AbstractTensor {
protected:
enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor };
enum AbstractTensorKind { kMlirTensor, kGraphTensor, kEagerTensor };
explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
public:
@ -101,7 +101,7 @@ class AbstractFunction {
// on a given context, with the same or different input tensors.
class AbstractOp {
protected:
enum AbstractOpKind { kGraphOp, kEagerOp };
enum AbstractOpKind { kMlirOp, kGraphOp, kEagerOp };
explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
public:
@ -129,7 +129,7 @@ class AbstractOp {
// eager implementation or to a graph implementation.
struct ExecutionContext {
protected:
enum ExecutionContextKind { kGraphContext, kEagerContext };
enum ExecutionContextKind { kMlirContext, kGraphContext, kEagerContext };
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
public:

View File

@ -477,7 +477,8 @@ TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
TF_DeleteExecutionContext(eager_execution_ctx);
}
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, ::testing::Values("graphdef"));
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
::testing::Values("graphdef", "mlir"));
} // namespace
} // namespace tensorflow

View File

@ -788,6 +788,9 @@ cc_library(
name = "convert_type",
srcs = ["utils/convert_type.cc"],
hdrs = ["utils/convert_type.h"],
visibility = [
"//visibility:public",
],
deps = [
":tensorflow_types",
"//tensorflow/core:framework",

View File

@ -0,0 +1,55 @@
load(
"//tensorflow:tensorflow.bzl",
"tf_copts",
"tf_cuda_library",
"tfe_xla_copts",
)
package(
default_visibility = [":friends"],
licenses = ["notice"], # Apache 2.0
)
package_group(
name = "friends",
packages = ["//tensorflow/..."],
)
tf_cuda_library(
name = "mlir_c_api",
srcs = [
"c_api_unified_experimental_mlir.cc",
],
copts = tf_copts() + tfe_xla_copts(),
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c:tf_status_internal",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:convert_type",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:casts",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
cc_library(
name = "mlir_c_api_registration",
srcs = ["c_api_unified_experimental_mlir_registration.cc"],
deps = [
":mlir_c_api",
"//tensorflow/c/eager:c_api_unified_internal",
],
alwayslink = 1,
)

View File

@ -0,0 +1,493 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstddef>
#include <memory>
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
namespace mlir {
namespace TF {
using tensorflow::internal::AbstractFunction;
using tensorflow::internal::AbstractOp;
using tensorflow::internal::AbstractTensor;
using tensorflow::internal::dyncast;
using tensorflow::internal::ExecutionContext;
using tensorflow::internal::OutputList;
namespace {
static void RegisterDialects() {
static bool init_once = []() {
mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
return true;
}();
(void)init_once;
}
Status ConvertDataTypeToTensor(tensorflow::DataType dtype, Builder builder,
Type* type) {
Status s = tensorflow::ConvertDataType(dtype, builder, type);
if (s.ok()) *type = UnrankedTensorType::get(*type);
return s;
}
class MlirTensor : public AbstractTensor {
public:
explicit MlirTensor(Value value) : AbstractTensor(kKind), value_(value) {}
Value getValue() { return value_; }
static constexpr AbstractTensorKind kKind = kMlirTensor;
private:
Value value_;
};
class MlirAbstractOp : public AbstractOp {
public:
explicit MlirAbstractOp(MLIRContext* context)
: AbstractOp(kKind), context_(context) {}
void SetOpType(const char* op_type, TF_Status* s) override;
void SetAttrType(const char* attr_name, TF_DataType dtype,
TF_Status* s) override;
void SetOpName(const char* const op_name, TF_Status* s) override;
MLIRContext* GetContext() { return context_; }
Type AddRef(Type type, TF_Status* s);
OperationState* Create(ArrayRef<Value> operands, TF_Status* s);
static constexpr AbstractOpKind kKind = kMlirOp;
private:
MLIRContext* context_;
llvm::StringMap<Attribute> attrs_;
std::unique_ptr<OperationState> state_;
const char* op_name_ = nullptr;
};
// MlirFunction is a thin wrapper over a FuncOp.
class MlirFunction : public AbstractFunction {
public:
explicit MlirFunction(std::unique_ptr<MLIRContext> context,
OwningModuleRef module, FuncOp func)
: AbstractFunction(kKind),
context_(std::move(context)),
module_(std::move(module)),
func_(func) {}
TF_Function* GetTfFunction(TF_Status* s) override;
static constexpr AbstractFunctionKind kKind = kGraphFunc;
private:
std::unique_ptr<MLIRContext> context_;
OwningModuleRef module_;
FuncOp func_;
};
class MlirFunctionContext : public ExecutionContext {
public:
explicit MlirFunctionContext(const char* name)
: ExecutionContext(kKind),
context_(std::make_unique<MLIRContext>()),
builder_(context_.get()) {
// TODO(aminim) figure out the location story here
module_ = ModuleOp::create(builder_.getUnknownLoc());
func_ = FuncOp::create(builder_.getUnknownLoc(), name,
builder_.getFunctionType(llvm::None, llvm::None));
module_->push_back(func_);
builder_ = OpBuilder::atBlockBegin(func_.addEntryBlock());
}
AbstractOp* CreateOperation() override {
return new MlirAbstractOp(context_.get());
}
void ExecuteOperation(AbstractOp* abstract_op, int num_inputs,
AbstractTensor* const* inputs, OutputList* o,
TF_Status* s) override;
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override;
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override;
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
s->status = tensorflow::errors::Unimplemented(
"Registering graph functions has not been implemented yet.");
}
static constexpr ExecutionContextKind kKind = kMlirContext;
private:
std::unique_ptr<MLIRContext> context_;
OpBuilder builder_;
FuncOp func_;
OwningModuleRef module_;
};
void MlirAbstractOp::SetOpType(const char* op_type, TF_Status* s) {
if (state_) {
s->status = tensorflow::errors::FailedPrecondition(
"SetOpType called on already built op.");
return;
}
std::string name = "tf.";
name += op_type;
// TODO(aminim) figure out the location story here
state_ = std::make_unique<OperationState>(UnknownLoc::get(context_), name);
}
void MlirAbstractOp::SetAttrType(const char* attr_name, TF_DataType dtype,
TF_Status* s) {
if (!state_) {
s->status = tensorflow::errors::FailedPrecondition(
"op_type must be specified before specifying attrs.");
return;
}
Type mlir_type;
Builder builder(context_);
s->status = ConvertDataTypeToTensor(static_cast<tensorflow::DataType>(dtype),
builder, &mlir_type);
if (!s->status.ok()) return;
attrs_[attr_name] = TypeAttr::get(mlir_type);
}
void MlirAbstractOp::SetOpName(const char* const op_name, TF_Status* s) {
// TODO(aminim): should we use a location?
if (op_name_) {
s->status = tensorflow::errors::FailedPrecondition(
"SetOpName called on already built op.");
return;
}
op_name_ = op_name;
}
Type MlirAbstractOp::AddRef(Type type, TF_Status* s) {
Type elt_type = getElementTypeOrSelf(type);
if (elt_type.isa<mlir::TF::TensorFlowRefType>()) {
s->status = tensorflow::errors::InvalidArgument(
"Requested reference to a reference type");
return nullptr;
}
elt_type = TensorFlowRefType::get(elt_type);
if (RankedTensorType tensor_type = type.dyn_cast<RankedTensorType>()) {
return RankedTensorType::get(tensor_type.getShape(), elt_type);
}
return UnrankedTensorType::get(elt_type);
}
OperationState* MlirAbstractOp::Create(ArrayRef<Value> operands, TF_Status* s) {
state_->operands = llvm::to_vector<4>(operands);
const tensorflow::OpDef* op_def;
auto node_name = state_->name.getStringRef().drop_front(
TensorFlowDialect::getDialectNamespace().size() + 1);
s->status =
tensorflow::OpRegistry::Global()->LookUpOpDef(node_name.str(), &op_def);
if (!s->status.ok()) return nullptr;
Builder builder(context_);
// Process operands according to the op_def and infer derived attributes.
int current_operand = 0;
for (const tensorflow::OpDef::ArgDef& input_arg : op_def->input_arg()) {
if (!input_arg.number_attr().empty()) {
// TODO(b/156122856): we don't support variadic operands.
s->status = tensorflow::errors::Unimplemented(
"Unsupported 'number_attr' for '", input_arg.number_attr(), "'");
return nullptr;
} else if (!input_arg.type_list_attr().empty()) {
s->status = tensorflow::errors::InvalidArgument(
"Unsupported 'type_list_attr' for '", input_arg.number_attr(), "'");
return nullptr;
}
if (current_operand >= operands.size()) {
s->status = tensorflow::errors::InvalidArgument("Missing operand for '",
input_arg.name(), "'");
return nullptr;
}
Type expected_type;
if (input_arg.type() != tensorflow::DT_INVALID) {
s->status =
ConvertDataTypeToTensor(input_arg.type(), builder, &expected_type);
if (!s->status.ok()) return nullptr;
if (input_arg.is_ref()) expected_type = AddRef(expected_type, s);
if (!s->status.ok()) return nullptr;
} else {
expected_type = operands[current_operand].getType();
}
if (!input_arg.type_attr().empty()) {
attrs_[input_arg.type_attr()] = TypeAttr::get(expected_type);
}
++current_operand;
}
for (const tensorflow::OpDef::ArgDef& output_arg : op_def->output_arg()) {
int original_size = state_->types.size();
if (!output_arg.number_attr().empty()) {
// Same type repeated "repeats" times.
Attribute repeats_attr = attrs_[output_arg.number_attr()];
if (!repeats_attr) {
s->status = tensorflow::errors::InvalidArgument(
"Missing attribute '", output_arg.number_attr(),
"' required for output list '", output_arg.name(), "'");
return nullptr;
}
if (!repeats_attr.isa<IntegerAttr>()) {
s->status = tensorflow::errors::InvalidArgument(
"Attribute '", output_arg.number_attr(),
"' required for output list '", output_arg.name(),
"' isn't an integer");
return nullptr;
}
int64_t repeats = repeats_attr.cast<IntegerAttr>().getInt();
if (!output_arg.type_attr().empty()) {
// Same type repeated "repeats" times.
Attribute attr = attrs_[output_arg.type_attr()];
if (!attr) {
s->status = tensorflow::errors::InvalidArgument(
"Missing attribute '", output_arg.type_attr(),
"' required for output '", output_arg.name(), "'");
return nullptr;
}
TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
if (!type_attr) {
s->status = tensorflow::errors::InvalidArgument(
"Attribute '", output_arg.type_attr(), "' required for output '",
output_arg.name(), "' isn't a type attribute");
return nullptr;
}
for (int i = 0; i < repeats; ++i)
state_->types.push_back(type_attr.getType());
} else if (output_arg.type() != tensorflow::DT_INVALID) {
for (int i = 0; i < repeats; ++i) {
Type type;
s->status =
ConvertDataTypeToTensor(output_arg.type(), builder, &type);
if (!s->status.ok()) return nullptr;
state_->types.push_back(type);
}
} else {
s->status = tensorflow::errors::InvalidArgument(
"Missing type or type_attr field in ",
output_arg.ShortDebugString());
return nullptr;
}
} else if (!output_arg.type_attr().empty()) {
Attribute attr = attrs_[output_arg.type_attr()];
if (!attr) {
s->status = tensorflow::errors::InvalidArgument(
"Missing attribute '", output_arg.type_attr(),
"' required for output '", output_arg.name(), "'");
return nullptr;
}
TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
if (!type_attr) {
s->status = tensorflow::errors::InvalidArgument(
"Attribute '", output_arg.type_attr(), "' required for output '",
output_arg.name(), "' isn't a type attribute");
return nullptr;
}
state_->types.push_back(type_attr.getValue());
} else if (!output_arg.type_list_attr().empty()) {
// This is pointing to an attribute which is an array of types.
Attribute attr = attrs_[output_arg.type_list_attr()];
if (!attr) {
s->status = tensorflow::errors::InvalidArgument(
"Missing attribute '", output_arg.type_list_attr(),
"' required for output '", output_arg.name(), "'");
return nullptr;
}
ArrayAttr array_attr = attr.dyn_cast<ArrayAttr>();
if (!array_attr) {
s->status = tensorflow::errors::InvalidArgument(
"Attribute '", output_arg.type_list_attr(),
"' required for output '", output_arg.name(),
"' isn't an array attribute");
return nullptr;
}
for (Attribute attr : array_attr) {
TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
if (!type_attr) {
s->status = tensorflow::errors::InvalidArgument(
"Array Attribute '", output_arg.type_list_attr(),
"' required for output '", output_arg.name(),
"' has a non-Type element");
return nullptr;
}
state_->types.push_back(type_attr.getValue());
}
} else if (output_arg.type() != tensorflow::DT_INVALID) {
Type type;
Builder builder(context_);
s->status = ConvertDataTypeToTensor(output_arg.type(), builder, &type);
if (!s->status.ok()) return nullptr;
state_->types.push_back(type);
} else {
s->status = tensorflow::errors::InvalidArgument(
"No type fields in ", output_arg.ShortDebugString());
if (!s->status.ok()) return nullptr;
}
if (output_arg.is_ref()) {
// For all types that were added by this function call, make them refs.
for (Type& type : llvm::make_range(&state_->types[original_size],
state_->types.end())) {
type = AddRef(type, s);
if (!s->status.ok()) return nullptr;
}
}
}
return state_.get();
}
TF_Function* MlirFunction::GetTfFunction(TF_Status* s) {
PassManager pm(func_.getContext());
pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
// In case of failure, the `diag_handler` converts MLIR errors emitted to
// the MLIRContext into a tensorflow::Status.
StatusScopedDiagnosticHandler diag_handler(func_.getContext());
LogicalResult result = pm.run(func_.getParentOfType<ModuleOp>());
(void)result;
s->status = diag_handler.ConsumeStatus();
if (!s->status.ok()) return nullptr;
tensorflow::GraphExportConfig configs;
std::unique_ptr<TF_Function> tf_function(new TF_Function);
s->status = ConvertMlirFunctionToFunctionLibraryDef(func_, configs,
&tf_function->fdef);
return tf_function.release();
}
void MlirFunctionContext::ExecuteOperation(AbstractOp* abstract_op,
int num_inputs,
AbstractTensor* const* inputs,
OutputList* o, TF_Status* s) {
auto* mlir_op = dyncast<MlirAbstractOp>(abstract_op);
if (mlir_op == nullptr) {
s->status = tensorflow::errors::InvalidArgument(
"Unable to cast AbstractOp to TF_GraphOp.");
return;
}
SmallVector<Value, 8> operands;
for (int i = 0; i < num_inputs; ++i) {
auto* operand = dyncast<MlirTensor>(inputs[i]);
if (!operand) {
s->status = tensorflow::errors::InvalidArgument(
"Capturing eager tensors is not supported yet.");
return;
}
if (operand->getValue().getContext() != context_.get()) {
s->status = tensorflow::errors::InvalidArgument(
"Capturing tensors from other context is not supported.");
return;
}
operands.push_back(operand->getValue());
}
OperationState* state = mlir_op->Create(operands, s);
if (!s->status.ok() || !state) return;
Operation* op = builder_.createOperation(*state);
int num_results = op->getNumResults();
o->outputs.clear();
o->outputs.reserve(num_results);
for (Value result : op->getResults())
o->outputs.push_back(new MlirTensor(result));
}
AbstractTensor* MlirFunctionContext::AddParameter(TF_DataType dtype,
TF_Status* s) {
Type type;
s->status = ConvertDataTypeToTensor(static_cast<tensorflow::DataType>(dtype),
builder_, &type);
if (!s->status.ok()) return nullptr;
return new MlirTensor(func_.getBody().front().addArgument(type));
}
AbstractFunction* MlirFunctionContext::Finalize(OutputList* outputs,
TF_Status* s) {
Block& body = func_.getBody().front();
SmallVector<Value, 8> ret_operands;
for (AbstractTensor* output : outputs->outputs) {
auto* operand = dyncast<MlirTensor>(output);
if (!operand) {
s->status = tensorflow::errors::InvalidArgument(
"Capturing eager tensors is not supported yet.");
return nullptr;
}
if (operand->getValue().getContext() != context_.get()) {
s->status = tensorflow::errors::InvalidArgument(
"Capturing tensors from other context is not supported.");
return nullptr;
}
ret_operands.push_back(operand->getValue());
}
builder_.create<ReturnOp>(func_.getLoc(), ret_operands);
auto arg_types = llvm::to_vector<8>(body.getArgumentTypes());
auto result_types =
llvm::to_vector<8>(body.getTerminator()->getOperandTypes());
func_.setType(FunctionType::get(arg_types, result_types, func_.getContext()));
return new MlirFunction(std::move(context_), std::move(module_), func_);
}
extern "C" {
ExecutionContext* MlirTracingFactory(const char* fn_name, TF_Status* s) {
RegisterDialects();
return new MlirFunctionContext(fn_name);
}
}
} // end anonymous namespace
} // end namespace TF
} // end namespace mlir

View File

@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
using tensorflow::internal::ExecutionContext;
extern "C" {
ExecutionContext* MlirTracingFactory(const char* fn_name, TF_Status* s);
}
namespace {
// Register the tracing implemented in this file as the default tracing engine.
static bool register_tracing = [] {
RegisterTracingEngineFactory("mlir", MlirTracingFactory);
return true;
}();
} // namespace

View File

@ -782,4 +782,22 @@ StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
return graphdef;
}
stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef(
mlir::FuncOp func, const GraphExportConfig& configs,
FunctionDef* function_def) {
Dialect* tf_dialect = func.getContext()->getRegisteredDialect("tf");
FunctionDefLibrary flib;
TF_RETURN_IF_ERROR(
Exporter::ConvertLibFunction(configs, tf_dialect, func, &flib));
for (auto& func_def : flib.function()) {
if (func_def.signature().name() == func.getName()) {
*function_def = func_def;
return Status::OK();
}
}
return errors::InvalidArgument(
"Function couldn't be found in the FunctionDefLibrary after converting "
"from MLIR");
}
} // namespace tensorflow

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
@ -50,6 +51,12 @@ stream_executor::port::Status ConvertMlirToGraph(
stream_executor::port::Status ConvertMlirToGraph(
mlir::ModuleOp module, const GraphExportConfig& configs,
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def);
// Converts an MLIR function and adds it to a FunctionLibraryDefinition.
stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef(
mlir::FuncOp func, const GraphExportConfig& configs,
FunctionDef* function_def);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_EXPORT_GRAPHDEF_H_