diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index e2781afc3e5..12021a294e8 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -216,6 +216,7 @@ tf_cuda_library( ], visibility = [ "//tensorflow/c:__subpackages__", + "//tensorflow/compiler/mlir/tensorflow/c:__subpackages__", ], deps = select({ "//tensorflow:android": [ diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index eb3035cc3d7..b8429646960 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -144,6 +144,24 @@ cc_library( ], ) +cc_library( + name = "c_api_unified_internal", + hdrs = [ + "c_api_unified_experimental_internal.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":c_api", + ":c_api_experimental", + "//tensorflow/c:c_api_internal", + "//tensorflow/c:tf_status", + "//tensorflow/core/platform:casts", + "//tensorflow/core/platform:types", + ], +) + cc_library( name = "tensor_handle_interface", hdrs = ["tensor_handle_interface.h"], @@ -514,6 +532,7 @@ tf_cuda_cc_test( "//tensorflow/c:c_api", "//tensorflow/c:c_test_util", "//tensorflow/cc/profiler", + "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/c/eager/c_api_unified_experimental_internal.h b/tensorflow/c/eager/c_api_unified_experimental_internal.h index 49212a230ee..8fc696f0f2f 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_internal.h +++ b/tensorflow/c/eager/c_api_unified_experimental_internal.h @@ -58,7 +58,7 @@ T* dyncast(S source) { // GraphContext and vice-versa). class AbstractTensor { protected: - enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor }; + enum AbstractTensorKind { kMlirTensor, kGraphTensor, kEagerTensor }; explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {} public: @@ -101,7 +101,7 @@ class AbstractFunction { // on a given context, with the same or different input tensors. class AbstractOp { protected: - enum AbstractOpKind { kGraphOp, kEagerOp }; + enum AbstractOpKind { kMlirOp, kGraphOp, kEagerOp }; explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {} public: @@ -129,7 +129,7 @@ class AbstractOp { // eager implementation or to a graph implementation. struct ExecutionContext { protected: - enum ExecutionContextKind { kGraphContext, kEagerContext }; + enum ExecutionContextKind { kMlirContext, kGraphContext, kEagerContext }; explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {} public: diff --git a/tensorflow/c/eager/c_api_unified_experimental_test.cc b/tensorflow/c/eager/c_api_unified_experimental_test.cc index 9776b4d13ed..24d170f2f99 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_test.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_test.cc @@ -477,7 +477,8 @@ TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) { TF_DeleteExecutionContext(eager_execution_ctx); } -INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, ::testing::Values("graphdef")); +INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, + ::testing::Values("graphdef", "mlir")); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index de0af94f0cb..5110ea7fbf5 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -788,6 +788,9 @@ cc_library( name = "convert_type", srcs = ["utils/convert_type.cc"], hdrs = ["utils/convert_type.h"], + visibility = [ + "//visibility:public", + ], deps = [ ":tensorflow_types", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/mlir/tensorflow/c/BUILD b/tensorflow/compiler/mlir/tensorflow/c/BUILD new file mode 100644 index 00000000000..3a503685fc6 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/c/BUILD @@ -0,0 +1,55 @@ +load( + "//tensorflow:tensorflow.bzl", + "tf_copts", + "tf_cuda_library", + "tfe_xla_copts", +) + +package( + default_visibility = [":friends"], + licenses = ["notice"], # Apache 2.0 +) + +package_group( + name = "friends", + packages = ["//tensorflow/..."], +) + +tf_cuda_library( + name = "mlir_c_api", + srcs = [ + "c_api_unified_experimental_mlir.cc", + ], + copts = tf_copts() + tfe_xla_copts(), + deps = [ + "//tensorflow/c:c_api", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c:tf_status_internal", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_internal", + "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", + "//tensorflow/compiler/mlir/tensorflow:convert_type", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:casts", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "mlir_c_api_registration", + srcs = ["c_api_unified_experimental_mlir_registration.cc"], + deps = [ + ":mlir_c_api", + "//tensorflow/c/eager:c_api_unified_internal", + ], + alwayslink = 1, +) diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc new file mode 100644 index 00000000000..0e8b7fedd9b --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc @@ -0,0 +1,493 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_status_internal.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" + +namespace mlir { +namespace TF { +using tensorflow::internal::AbstractFunction; +using tensorflow::internal::AbstractOp; +using tensorflow::internal::AbstractTensor; +using tensorflow::internal::dyncast; +using tensorflow::internal::ExecutionContext; +using tensorflow::internal::OutputList; + +namespace { + +static void RegisterDialects() { + static bool init_once = []() { + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + return true; + }(); + (void)init_once; +} + +Status ConvertDataTypeToTensor(tensorflow::DataType dtype, Builder builder, + Type* type) { + Status s = tensorflow::ConvertDataType(dtype, builder, type); + if (s.ok()) *type = UnrankedTensorType::get(*type); + return s; +} + +class MlirTensor : public AbstractTensor { + public: + explicit MlirTensor(Value value) : AbstractTensor(kKind), value_(value) {} + + Value getValue() { return value_; } + + static constexpr AbstractTensorKind kKind = kMlirTensor; + + private: + Value value_; +}; + +class MlirAbstractOp : public AbstractOp { + public: + explicit MlirAbstractOp(MLIRContext* context) + : AbstractOp(kKind), context_(context) {} + + void SetOpType(const char* op_type, TF_Status* s) override; + + void SetAttrType(const char* attr_name, TF_DataType dtype, + TF_Status* s) override; + + void SetOpName(const char* const op_name, TF_Status* s) override; + + MLIRContext* GetContext() { return context_; } + + Type AddRef(Type type, TF_Status* s); + + OperationState* Create(ArrayRef operands, TF_Status* s); + + static constexpr AbstractOpKind kKind = kMlirOp; + + private: + MLIRContext* context_; + llvm::StringMap attrs_; + std::unique_ptr state_; + const char* op_name_ = nullptr; +}; + +// MlirFunction is a thin wrapper over a FuncOp. +class MlirFunction : public AbstractFunction { + public: + explicit MlirFunction(std::unique_ptr context, + OwningModuleRef module, FuncOp func) + : AbstractFunction(kKind), + context_(std::move(context)), + module_(std::move(module)), + func_(func) {} + + TF_Function* GetTfFunction(TF_Status* s) override; + + static constexpr AbstractFunctionKind kKind = kGraphFunc; + + private: + std::unique_ptr context_; + OwningModuleRef module_; + FuncOp func_; +}; + +class MlirFunctionContext : public ExecutionContext { + public: + explicit MlirFunctionContext(const char* name) + : ExecutionContext(kKind), + context_(std::make_unique()), + builder_(context_.get()) { + // TODO(aminim) figure out the location story here + module_ = ModuleOp::create(builder_.getUnknownLoc()); + func_ = FuncOp::create(builder_.getUnknownLoc(), name, + builder_.getFunctionType(llvm::None, llvm::None)); + module_->push_back(func_); + builder_ = OpBuilder::atBlockBegin(func_.addEntryBlock()); + } + + AbstractOp* CreateOperation() override { + return new MlirAbstractOp(context_.get()); + } + + void ExecuteOperation(AbstractOp* abstract_op, int num_inputs, + AbstractTensor* const* inputs, OutputList* o, + TF_Status* s) override; + + AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override; + + AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override; + + void RegisterFunction(AbstractFunction* func, TF_Status* s) override { + s->status = tensorflow::errors::Unimplemented( + "Registering graph functions has not been implemented yet."); + } + + static constexpr ExecutionContextKind kKind = kMlirContext; + + private: + std::unique_ptr context_; + OpBuilder builder_; + FuncOp func_; + OwningModuleRef module_; +}; + +void MlirAbstractOp::SetOpType(const char* op_type, TF_Status* s) { + if (state_) { + s->status = tensorflow::errors::FailedPrecondition( + "SetOpType called on already built op."); + return; + } + std::string name = "tf."; + name += op_type; + // TODO(aminim) figure out the location story here + state_ = std::make_unique(UnknownLoc::get(context_), name); +} + +void MlirAbstractOp::SetAttrType(const char* attr_name, TF_DataType dtype, + TF_Status* s) { + if (!state_) { + s->status = tensorflow::errors::FailedPrecondition( + "op_type must be specified before specifying attrs."); + return; + } + Type mlir_type; + Builder builder(context_); + s->status = ConvertDataTypeToTensor(static_cast(dtype), + builder, &mlir_type); + if (!s->status.ok()) return; + attrs_[attr_name] = TypeAttr::get(mlir_type); +} + +void MlirAbstractOp::SetOpName(const char* const op_name, TF_Status* s) { + // TODO(aminim): should we use a location? + if (op_name_) { + s->status = tensorflow::errors::FailedPrecondition( + "SetOpName called on already built op."); + return; + } + op_name_ = op_name; +} + +Type MlirAbstractOp::AddRef(Type type, TF_Status* s) { + Type elt_type = getElementTypeOrSelf(type); + if (elt_type.isa()) { + s->status = tensorflow::errors::InvalidArgument( + "Requested reference to a reference type"); + return nullptr; + } + elt_type = TensorFlowRefType::get(elt_type); + if (RankedTensorType tensor_type = type.dyn_cast()) { + return RankedTensorType::get(tensor_type.getShape(), elt_type); + } + return UnrankedTensorType::get(elt_type); +} + +OperationState* MlirAbstractOp::Create(ArrayRef operands, TF_Status* s) { + state_->operands = llvm::to_vector<4>(operands); + const tensorflow::OpDef* op_def; + auto node_name = state_->name.getStringRef().drop_front( + TensorFlowDialect::getDialectNamespace().size() + 1); + s->status = + tensorflow::OpRegistry::Global()->LookUpOpDef(node_name.str(), &op_def); + if (!s->status.ok()) return nullptr; + Builder builder(context_); + // Process operands according to the op_def and infer derived attributes. + int current_operand = 0; + for (const tensorflow::OpDef::ArgDef& input_arg : op_def->input_arg()) { + if (!input_arg.number_attr().empty()) { + // TODO(b/156122856): we don't support variadic operands. + s->status = tensorflow::errors::Unimplemented( + "Unsupported 'number_attr' for '", input_arg.number_attr(), "'"); + return nullptr; + } else if (!input_arg.type_list_attr().empty()) { + s->status = tensorflow::errors::InvalidArgument( + "Unsupported 'type_list_attr' for '", input_arg.number_attr(), "'"); + return nullptr; + } + if (current_operand >= operands.size()) { + s->status = tensorflow::errors::InvalidArgument("Missing operand for '", + input_arg.name(), "'"); + return nullptr; + } + Type expected_type; + if (input_arg.type() != tensorflow::DT_INVALID) { + s->status = + ConvertDataTypeToTensor(input_arg.type(), builder, &expected_type); + if (!s->status.ok()) return nullptr; + if (input_arg.is_ref()) expected_type = AddRef(expected_type, s); + if (!s->status.ok()) return nullptr; + } else { + expected_type = operands[current_operand].getType(); + } + if (!input_arg.type_attr().empty()) { + attrs_[input_arg.type_attr()] = TypeAttr::get(expected_type); + } + ++current_operand; + } + + for (const tensorflow::OpDef::ArgDef& output_arg : op_def->output_arg()) { + int original_size = state_->types.size(); + if (!output_arg.number_attr().empty()) { + // Same type repeated "repeats" times. + Attribute repeats_attr = attrs_[output_arg.number_attr()]; + if (!repeats_attr) { + s->status = tensorflow::errors::InvalidArgument( + "Missing attribute '", output_arg.number_attr(), + "' required for output list '", output_arg.name(), "'"); + return nullptr; + } + if (!repeats_attr.isa()) { + s->status = tensorflow::errors::InvalidArgument( + "Attribute '", output_arg.number_attr(), + "' required for output list '", output_arg.name(), + "' isn't an integer"); + return nullptr; + } + int64_t repeats = repeats_attr.cast().getInt(); + + if (!output_arg.type_attr().empty()) { + // Same type repeated "repeats" times. + Attribute attr = attrs_[output_arg.type_attr()]; + if (!attr) { + s->status = tensorflow::errors::InvalidArgument( + "Missing attribute '", output_arg.type_attr(), + "' required for output '", output_arg.name(), "'"); + return nullptr; + } + TypeAttr type_attr = attr.dyn_cast(); + if (!type_attr) { + s->status = tensorflow::errors::InvalidArgument( + "Attribute '", output_arg.type_attr(), "' required for output '", + output_arg.name(), "' isn't a type attribute"); + return nullptr; + } + for (int i = 0; i < repeats; ++i) + state_->types.push_back(type_attr.getType()); + } else if (output_arg.type() != tensorflow::DT_INVALID) { + for (int i = 0; i < repeats; ++i) { + Type type; + s->status = + ConvertDataTypeToTensor(output_arg.type(), builder, &type); + if (!s->status.ok()) return nullptr; + state_->types.push_back(type); + } + } else { + s->status = tensorflow::errors::InvalidArgument( + "Missing type or type_attr field in ", + output_arg.ShortDebugString()); + return nullptr; + } + } else if (!output_arg.type_attr().empty()) { + Attribute attr = attrs_[output_arg.type_attr()]; + if (!attr) { + s->status = tensorflow::errors::InvalidArgument( + "Missing attribute '", output_arg.type_attr(), + "' required for output '", output_arg.name(), "'"); + return nullptr; + } + TypeAttr type_attr = attr.dyn_cast(); + if (!type_attr) { + s->status = tensorflow::errors::InvalidArgument( + "Attribute '", output_arg.type_attr(), "' required for output '", + output_arg.name(), "' isn't a type attribute"); + return nullptr; + } + state_->types.push_back(type_attr.getValue()); + } else if (!output_arg.type_list_attr().empty()) { + // This is pointing to an attribute which is an array of types. + Attribute attr = attrs_[output_arg.type_list_attr()]; + if (!attr) { + s->status = tensorflow::errors::InvalidArgument( + "Missing attribute '", output_arg.type_list_attr(), + "' required for output '", output_arg.name(), "'"); + return nullptr; + } + ArrayAttr array_attr = attr.dyn_cast(); + if (!array_attr) { + s->status = tensorflow::errors::InvalidArgument( + "Attribute '", output_arg.type_list_attr(), + "' required for output '", output_arg.name(), + "' isn't an array attribute"); + return nullptr; + } + for (Attribute attr : array_attr) { + TypeAttr type_attr = attr.dyn_cast(); + if (!type_attr) { + s->status = tensorflow::errors::InvalidArgument( + "Array Attribute '", output_arg.type_list_attr(), + "' required for output '", output_arg.name(), + "' has a non-Type element"); + return nullptr; + } + state_->types.push_back(type_attr.getValue()); + } + } else if (output_arg.type() != tensorflow::DT_INVALID) { + Type type; + Builder builder(context_); + s->status = ConvertDataTypeToTensor(output_arg.type(), builder, &type); + if (!s->status.ok()) return nullptr; + state_->types.push_back(type); + } else { + s->status = tensorflow::errors::InvalidArgument( + "No type fields in ", output_arg.ShortDebugString()); + if (!s->status.ok()) return nullptr; + } + if (output_arg.is_ref()) { + // For all types that were added by this function call, make them refs. + for (Type& type : llvm::make_range(&state_->types[original_size], + state_->types.end())) { + type = AddRef(type, s); + if (!s->status.ok()) return nullptr; + } + } + } + return state_.get(); +} + +TF_Function* MlirFunction::GetTfFunction(TF_Status* s) { + PassManager pm(func_.getContext()); + pm.addNestedPass(CreateFunctionalToExecutorDialectConversionPass()); + pm.addNestedPass(CreateBreakUpIslandsPass()); + + // In case of failure, the `diag_handler` converts MLIR errors emitted to + // the MLIRContext into a tensorflow::Status. + StatusScopedDiagnosticHandler diag_handler(func_.getContext()); + LogicalResult result = pm.run(func_.getParentOfType()); + (void)result; + s->status = diag_handler.ConsumeStatus(); + if (!s->status.ok()) return nullptr; + + tensorflow::GraphExportConfig configs; + std::unique_ptr tf_function(new TF_Function); + s->status = ConvertMlirFunctionToFunctionLibraryDef(func_, configs, + &tf_function->fdef); + return tf_function.release(); +} + +void MlirFunctionContext::ExecuteOperation(AbstractOp* abstract_op, + int num_inputs, + AbstractTensor* const* inputs, + OutputList* o, TF_Status* s) { + auto* mlir_op = dyncast(abstract_op); + if (mlir_op == nullptr) { + s->status = tensorflow::errors::InvalidArgument( + "Unable to cast AbstractOp to TF_GraphOp."); + return; + } + SmallVector operands; + for (int i = 0; i < num_inputs; ++i) { + auto* operand = dyncast(inputs[i]); + if (!operand) { + s->status = tensorflow::errors::InvalidArgument( + "Capturing eager tensors is not supported yet."); + return; + } + if (operand->getValue().getContext() != context_.get()) { + s->status = tensorflow::errors::InvalidArgument( + "Capturing tensors from other context is not supported."); + return; + } + operands.push_back(operand->getValue()); + } + OperationState* state = mlir_op->Create(operands, s); + if (!s->status.ok() || !state) return; + Operation* op = builder_.createOperation(*state); + int num_results = op->getNumResults(); + o->outputs.clear(); + o->outputs.reserve(num_results); + for (Value result : op->getResults()) + o->outputs.push_back(new MlirTensor(result)); +} + +AbstractTensor* MlirFunctionContext::AddParameter(TF_DataType dtype, + TF_Status* s) { + Type type; + s->status = ConvertDataTypeToTensor(static_cast(dtype), + builder_, &type); + if (!s->status.ok()) return nullptr; + return new MlirTensor(func_.getBody().front().addArgument(type)); +} + +AbstractFunction* MlirFunctionContext::Finalize(OutputList* outputs, + TF_Status* s) { + Block& body = func_.getBody().front(); + SmallVector ret_operands; + for (AbstractTensor* output : outputs->outputs) { + auto* operand = dyncast(output); + if (!operand) { + s->status = tensorflow::errors::InvalidArgument( + "Capturing eager tensors is not supported yet."); + return nullptr; + } + if (operand->getValue().getContext() != context_.get()) { + s->status = tensorflow::errors::InvalidArgument( + "Capturing tensors from other context is not supported."); + return nullptr; + } + ret_operands.push_back(operand->getValue()); + } + builder_.create(func_.getLoc(), ret_operands); + + auto arg_types = llvm::to_vector<8>(body.getArgumentTypes()); + auto result_types = + llvm::to_vector<8>(body.getTerminator()->getOperandTypes()); + func_.setType(FunctionType::get(arg_types, result_types, func_.getContext())); + return new MlirFunction(std::move(context_), std::move(module_), func_); +} + +extern "C" { +ExecutionContext* MlirTracingFactory(const char* fn_name, TF_Status* s) { + RegisterDialects(); + return new MlirFunctionContext(fn_name); +} +} + +} // end anonymous namespace +} // end namespace TF +} // end namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir_registration.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir_registration.cc new file mode 100644 index 00000000000..778f4b777a3 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir_registration.cc @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" + +using tensorflow::internal::ExecutionContext; + +extern "C" { +ExecutionContext* MlirTracingFactory(const char* fn_name, TF_Status* s); +} + +namespace { +// Register the tracing implemented in this file as the default tracing engine. +static bool register_tracing = [] { + RegisterTracingEngineFactory("mlir", MlirTracingFactory); + return true; +}(); + +} // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 75fcede8fbb..2bf55922d4b 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -782,4 +782,22 @@ StatusOr> ConvertMlirToGraphdef( return graphdef; } +stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef( + mlir::FuncOp func, const GraphExportConfig& configs, + FunctionDef* function_def) { + Dialect* tf_dialect = func.getContext()->getRegisteredDialect("tf"); + FunctionDefLibrary flib; + TF_RETURN_IF_ERROR( + Exporter::ConvertLibFunction(configs, tf_dialect, func, &flib)); + for (auto& func_def : flib.function()) { + if (func_def.signature().name() == func.getName()) { + *function_def = func_def; + return Status::OK(); + } + } + return errors::InvalidArgument( + "Function couldn't be found in the FunctionDefLibrary after converting " + "from MLIR"); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h index 2d522f6031e..a5aebd16146 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h @@ -18,6 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "llvm/ADT/StringRef.h" +#include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project @@ -50,6 +51,12 @@ stream_executor::port::Status ConvertMlirToGraph( stream_executor::port::Status ConvertMlirToGraph( mlir::ModuleOp module, const GraphExportConfig& configs, std::unique_ptr* graph, FunctionLibraryDefinition* flib_def); + +// Converts an MLIR function and adds it to a FunctionLibraryDefinition. +stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef( + mlir::FuncOp func, const GraphExportConfig& configs, + FunctionDef* function_def); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_EXPORT_GRAPHDEF_H_