Add an MLIR tracing implementation to the C unified API
This is plumbing just enough to pass all the unit-tests. The conversion to the function library is quite inefficient, but it isn't clear if we want to optimize this or just focus on TFRT moving forward. PiperOrigin-RevId: 313356850 Change-Id: I83815317d4958786d0103168b5d88498f89511ed
This commit is contained in:
parent
7738c1818e
commit
f0ef163443
|
@ -216,6 +216,7 @@ tf_cuda_library(
|
||||||
],
|
],
|
||||||
visibility = [
|
visibility = [
|
||||||
"//tensorflow/c:__subpackages__",
|
"//tensorflow/c:__subpackages__",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow/c:__subpackages__",
|
||||||
],
|
],
|
||||||
deps = select({
|
deps = select({
|
||||||
"//tensorflow:android": [
|
"//tensorflow:android": [
|
||||||
|
|
|
@ -144,6 +144,24 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "c_api_unified_internal",
|
||||||
|
hdrs = [
|
||||||
|
"c_api_unified_experimental_internal.h",
|
||||||
|
],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow:internal",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":c_api",
|
||||||
|
":c_api_experimental",
|
||||||
|
"//tensorflow/c:c_api_internal",
|
||||||
|
"//tensorflow/c:tf_status",
|
||||||
|
"//tensorflow/core/platform:casts",
|
||||||
|
"//tensorflow/core/platform:types",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tensor_handle_interface",
|
name = "tensor_handle_interface",
|
||||||
hdrs = ["tensor_handle_interface.h"],
|
hdrs = ["tensor_handle_interface.h"],
|
||||||
|
@ -514,6 +532,7 @@ tf_cuda_cc_test(
|
||||||
"//tensorflow/c:c_api",
|
"//tensorflow/c:c_api",
|
||||||
"//tensorflow/c:c_test_util",
|
"//tensorflow/c:c_test_util",
|
||||||
"//tensorflow/cc/profiler",
|
"//tensorflow/cc/profiler",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
|
|
|
@ -58,7 +58,7 @@ T* dyncast(S source) {
|
||||||
// GraphContext and vice-versa).
|
// GraphContext and vice-versa).
|
||||||
class AbstractTensor {
|
class AbstractTensor {
|
||||||
protected:
|
protected:
|
||||||
enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor };
|
enum AbstractTensorKind { kMlirTensor, kGraphTensor, kEagerTensor };
|
||||||
explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
|
explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
@ -101,7 +101,7 @@ class AbstractFunction {
|
||||||
// on a given context, with the same or different input tensors.
|
// on a given context, with the same or different input tensors.
|
||||||
class AbstractOp {
|
class AbstractOp {
|
||||||
protected:
|
protected:
|
||||||
enum AbstractOpKind { kGraphOp, kEagerOp };
|
enum AbstractOpKind { kMlirOp, kGraphOp, kEagerOp };
|
||||||
explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
|
explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
@ -129,7 +129,7 @@ class AbstractOp {
|
||||||
// eager implementation or to a graph implementation.
|
// eager implementation or to a graph implementation.
|
||||||
struct ExecutionContext {
|
struct ExecutionContext {
|
||||||
protected:
|
protected:
|
||||||
enum ExecutionContextKind { kGraphContext, kEagerContext };
|
enum ExecutionContextKind { kMlirContext, kGraphContext, kEagerContext };
|
||||||
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
|
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -477,7 +477,8 @@ TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
|
||||||
TF_DeleteExecutionContext(eager_execution_ctx);
|
TF_DeleteExecutionContext(eager_execution_ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, ::testing::Values("graphdef"));
|
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
|
||||||
|
::testing::Values("graphdef", "mlir"));
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
|
@ -788,6 +788,9 @@ cc_library(
|
||||||
name = "convert_type",
|
name = "convert_type",
|
||||||
srcs = ["utils/convert_type.cc"],
|
srcs = ["utils/convert_type.cc"],
|
||||||
hdrs = ["utils/convert_type.h"],
|
hdrs = ["utils/convert_type.h"],
|
||||||
|
visibility = [
|
||||||
|
"//visibility:public",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":tensorflow_types",
|
":tensorflow_types",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
|
|
@ -0,0 +1,55 @@
|
||||||
|
load(
|
||||||
|
"//tensorflow:tensorflow.bzl",
|
||||||
|
"tf_copts",
|
||||||
|
"tf_cuda_library",
|
||||||
|
"tfe_xla_copts",
|
||||||
|
)
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = [":friends"],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
package_group(
|
||||||
|
name = "friends",
|
||||||
|
packages = ["//tensorflow/..."],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cuda_library(
|
||||||
|
name = "mlir_c_api",
|
||||||
|
srcs = [
|
||||||
|
"c_api_unified_experimental_mlir.cc",
|
||||||
|
],
|
||||||
|
copts = tf_copts() + tfe_xla_copts(),
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/c:c_api",
|
||||||
|
"//tensorflow/c:tf_status_helper",
|
||||||
|
"//tensorflow/c:tf_status_internal",
|
||||||
|
"//tensorflow/c/eager:c_api",
|
||||||
|
"//tensorflow/c/eager:c_api_internal",
|
||||||
|
"//tensorflow/c/eager:c_api_unified_internal",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow:convert_type",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core/platform:casts",
|
||||||
|
"@llvm-project//llvm:support",
|
||||||
|
"@llvm-project//mlir:IR",
|
||||||
|
"@llvm-project//mlir:Pass",
|
||||||
|
"@llvm-project//mlir:StandardOps",
|
||||||
|
"@llvm-project//mlir:Support",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "mlir_c_api_registration",
|
||||||
|
srcs = ["c_api_unified_experimental_mlir_registration.cc"],
|
||||||
|
deps = [
|
||||||
|
":mlir_c_api",
|
||||||
|
"//tensorflow/c/eager:c_api_unified_internal",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
|
@ -0,0 +1,493 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
#include "llvm/ADT/iterator_range.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Function.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Location.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Module.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||||
|
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||||
|
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||||
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||||
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
|
#include "tensorflow/c/tf_status_internal.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||||
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace TF {
|
||||||
|
using tensorflow::internal::AbstractFunction;
|
||||||
|
using tensorflow::internal::AbstractOp;
|
||||||
|
using tensorflow::internal::AbstractTensor;
|
||||||
|
using tensorflow::internal::dyncast;
|
||||||
|
using tensorflow::internal::ExecutionContext;
|
||||||
|
using tensorflow::internal::OutputList;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static void RegisterDialects() {
|
||||||
|
static bool init_once = []() {
|
||||||
|
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||||
|
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||||
|
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
|
||||||
|
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
|
||||||
|
return true;
|
||||||
|
}();
|
||||||
|
(void)init_once;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ConvertDataTypeToTensor(tensorflow::DataType dtype, Builder builder,
|
||||||
|
Type* type) {
|
||||||
|
Status s = tensorflow::ConvertDataType(dtype, builder, type);
|
||||||
|
if (s.ok()) *type = UnrankedTensorType::get(*type);
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
class MlirTensor : public AbstractTensor {
|
||||||
|
public:
|
||||||
|
explicit MlirTensor(Value value) : AbstractTensor(kKind), value_(value) {}
|
||||||
|
|
||||||
|
Value getValue() { return value_; }
|
||||||
|
|
||||||
|
static constexpr AbstractTensorKind kKind = kMlirTensor;
|
||||||
|
|
||||||
|
private:
|
||||||
|
Value value_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MlirAbstractOp : public AbstractOp {
|
||||||
|
public:
|
||||||
|
explicit MlirAbstractOp(MLIRContext* context)
|
||||||
|
: AbstractOp(kKind), context_(context) {}
|
||||||
|
|
||||||
|
void SetOpType(const char* op_type, TF_Status* s) override;
|
||||||
|
|
||||||
|
void SetAttrType(const char* attr_name, TF_DataType dtype,
|
||||||
|
TF_Status* s) override;
|
||||||
|
|
||||||
|
void SetOpName(const char* const op_name, TF_Status* s) override;
|
||||||
|
|
||||||
|
MLIRContext* GetContext() { return context_; }
|
||||||
|
|
||||||
|
Type AddRef(Type type, TF_Status* s);
|
||||||
|
|
||||||
|
OperationState* Create(ArrayRef<Value> operands, TF_Status* s);
|
||||||
|
|
||||||
|
static constexpr AbstractOpKind kKind = kMlirOp;
|
||||||
|
|
||||||
|
private:
|
||||||
|
MLIRContext* context_;
|
||||||
|
llvm::StringMap<Attribute> attrs_;
|
||||||
|
std::unique_ptr<OperationState> state_;
|
||||||
|
const char* op_name_ = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
// MlirFunction is a thin wrapper over a FuncOp.
|
||||||
|
class MlirFunction : public AbstractFunction {
|
||||||
|
public:
|
||||||
|
explicit MlirFunction(std::unique_ptr<MLIRContext> context,
|
||||||
|
OwningModuleRef module, FuncOp func)
|
||||||
|
: AbstractFunction(kKind),
|
||||||
|
context_(std::move(context)),
|
||||||
|
module_(std::move(module)),
|
||||||
|
func_(func) {}
|
||||||
|
|
||||||
|
TF_Function* GetTfFunction(TF_Status* s) override;
|
||||||
|
|
||||||
|
static constexpr AbstractFunctionKind kKind = kGraphFunc;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<MLIRContext> context_;
|
||||||
|
OwningModuleRef module_;
|
||||||
|
FuncOp func_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MlirFunctionContext : public ExecutionContext {
|
||||||
|
public:
|
||||||
|
explicit MlirFunctionContext(const char* name)
|
||||||
|
: ExecutionContext(kKind),
|
||||||
|
context_(std::make_unique<MLIRContext>()),
|
||||||
|
builder_(context_.get()) {
|
||||||
|
// TODO(aminim) figure out the location story here
|
||||||
|
module_ = ModuleOp::create(builder_.getUnknownLoc());
|
||||||
|
func_ = FuncOp::create(builder_.getUnknownLoc(), name,
|
||||||
|
builder_.getFunctionType(llvm::None, llvm::None));
|
||||||
|
module_->push_back(func_);
|
||||||
|
builder_ = OpBuilder::atBlockBegin(func_.addEntryBlock());
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractOp* CreateOperation() override {
|
||||||
|
return new MlirAbstractOp(context_.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
void ExecuteOperation(AbstractOp* abstract_op, int num_inputs,
|
||||||
|
AbstractTensor* const* inputs, OutputList* o,
|
||||||
|
TF_Status* s) override;
|
||||||
|
|
||||||
|
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override;
|
||||||
|
|
||||||
|
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override;
|
||||||
|
|
||||||
|
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
|
||||||
|
s->status = tensorflow::errors::Unimplemented(
|
||||||
|
"Registering graph functions has not been implemented yet.");
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr ExecutionContextKind kKind = kMlirContext;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<MLIRContext> context_;
|
||||||
|
OpBuilder builder_;
|
||||||
|
FuncOp func_;
|
||||||
|
OwningModuleRef module_;
|
||||||
|
};
|
||||||
|
|
||||||
|
void MlirAbstractOp::SetOpType(const char* op_type, TF_Status* s) {
|
||||||
|
if (state_) {
|
||||||
|
s->status = tensorflow::errors::FailedPrecondition(
|
||||||
|
"SetOpType called on already built op.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
std::string name = "tf.";
|
||||||
|
name += op_type;
|
||||||
|
// TODO(aminim) figure out the location story here
|
||||||
|
state_ = std::make_unique<OperationState>(UnknownLoc::get(context_), name);
|
||||||
|
}
|
||||||
|
|
||||||
|
void MlirAbstractOp::SetAttrType(const char* attr_name, TF_DataType dtype,
|
||||||
|
TF_Status* s) {
|
||||||
|
if (!state_) {
|
||||||
|
s->status = tensorflow::errors::FailedPrecondition(
|
||||||
|
"op_type must be specified before specifying attrs.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Type mlir_type;
|
||||||
|
Builder builder(context_);
|
||||||
|
s->status = ConvertDataTypeToTensor(static_cast<tensorflow::DataType>(dtype),
|
||||||
|
builder, &mlir_type);
|
||||||
|
if (!s->status.ok()) return;
|
||||||
|
attrs_[attr_name] = TypeAttr::get(mlir_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
void MlirAbstractOp::SetOpName(const char* const op_name, TF_Status* s) {
|
||||||
|
// TODO(aminim): should we use a location?
|
||||||
|
if (op_name_) {
|
||||||
|
s->status = tensorflow::errors::FailedPrecondition(
|
||||||
|
"SetOpName called on already built op.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
op_name_ = op_name;
|
||||||
|
}
|
||||||
|
|
||||||
|
Type MlirAbstractOp::AddRef(Type type, TF_Status* s) {
|
||||||
|
Type elt_type = getElementTypeOrSelf(type);
|
||||||
|
if (elt_type.isa<mlir::TF::TensorFlowRefType>()) {
|
||||||
|
s->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Requested reference to a reference type");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
elt_type = TensorFlowRefType::get(elt_type);
|
||||||
|
if (RankedTensorType tensor_type = type.dyn_cast<RankedTensorType>()) {
|
||||||
|
return RankedTensorType::get(tensor_type.getShape(), elt_type);
|
||||||
|
}
|
||||||
|
return UnrankedTensorType::get(elt_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
OperationState* MlirAbstractOp::Create(ArrayRef<Value> operands, TF_Status* s) {
|
||||||
|
state_->operands = llvm::to_vector<4>(operands);
|
||||||
|
const tensorflow::OpDef* op_def;
|
||||||
|
auto node_name = state_->name.getStringRef().drop_front(
|
||||||
|
TensorFlowDialect::getDialectNamespace().size() + 1);
|
||||||
|
s->status =
|
||||||
|
tensorflow::OpRegistry::Global()->LookUpOpDef(node_name.str(), &op_def);
|
||||||
|
if (!s->status.ok()) return nullptr;
|
||||||
|
Builder builder(context_);
|
||||||
|
// Process operands according to the op_def and infer derived attributes.
|
||||||
|
int current_operand = 0;
|
||||||
|
for (const tensorflow::OpDef::ArgDef& input_arg : op_def->input_arg()) {
|
||||||
|
if (!input_arg.number_attr().empty()) {
|
||||||
|
// TODO(b/156122856): we don't support variadic operands.
|
||||||
|
s->status = tensorflow::errors::Unimplemented(
|
||||||
|
"Unsupported 'number_attr' for '", input_arg.number_attr(), "'");
|
||||||
|
return nullptr;
|
||||||
|
} else if (!input_arg.type_list_attr().empty()) {
|
||||||
|
s->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Unsupported 'type_list_attr' for '", input_arg.number_attr(), "'");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (current_operand >= operands.size()) {
|
||||||
|
s->status = tensorflow::errors::InvalidArgument("Missing operand for '",
|
||||||
|
input_arg.name(), "'");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
Type expected_type;
|
||||||
|
if (input_arg.type() != tensorflow::DT_INVALID) {
|
||||||
|
s->status =
|
||||||
|
ConvertDataTypeToTensor(input_arg.type(), builder, &expected_type);
|
||||||
|
if (!s->status.ok()) return nullptr;
|
||||||
|
if (input_arg.is_ref()) expected_type = AddRef(expected_type, s);
|
||||||
|
if (!s->status.ok()) return nullptr;
|
||||||
|
} else {
|
||||||
|
expected_type = operands[current_operand].getType();
|
||||||
|
}
|
||||||
|
if (!input_arg.type_attr().empty()) {
|
||||||
|
attrs_[input_arg.type_attr()] = TypeAttr::get(expected_type);
|
||||||
|
}
|
||||||
|
++current_operand;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const tensorflow::OpDef::ArgDef& output_arg : op_def->output_arg()) {
|
||||||
|
int original_size = state_->types.size();
|
||||||
|
if (!output_arg.number_attr().empty()) {
|
||||||
|
// Same type repeated "repeats" times.
|
||||||
|
Attribute repeats_attr = attrs_[output_arg.number_attr()];
|
||||||
|
if (!repeats_attr) {
|
||||||
|
s->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Missing attribute '", output_arg.number_attr(),
|
||||||
|
"' required for output list '", output_arg.name(), "'");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (!repeats_attr.isa<IntegerAttr>()) {
|
||||||
|
s->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Attribute '", output_arg.number_attr(),
|
||||||
|
"' required for output list '", output_arg.name(),
|
||||||
|
"' isn't an integer");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
int64_t repeats = repeats_attr.cast<IntegerAttr>().getInt();
|
||||||
|
|
||||||
|
if (!output_arg.type_attr().empty()) {
|
||||||
|
// Same type repeated "repeats" times.
|
||||||
|
Attribute attr = attrs_[output_arg.type_attr()];
|
||||||
|
if (!attr) {
|
||||||
|
s->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Missing attribute '", output_arg.type_attr(),
|
||||||
|
"' required for output '", output_arg.name(), "'");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
|
||||||
|
if (!type_attr) {
|
||||||
|
s->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Attribute '", output_arg.type_attr(), "' required for output '",
|
||||||
|
output_arg.name(), "' isn't a type attribute");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < repeats; ++i)
|
||||||
|
state_->types.push_back(type_attr.getType());
|
||||||
|
} else if (output_arg.type() != tensorflow::DT_INVALID) {
|
||||||
|
for (int i = 0; i < repeats; ++i) {
|
||||||
|
Type type;
|
||||||
|
s->status =
|
||||||
|
ConvertDataTypeToTensor(output_arg.type(), builder, &type);
|
||||||
|
if (!s->status.ok()) return nullptr;
|
||||||
|
state_->types.push_back(type);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
s->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Missing type or type_attr field in ",
|
||||||
|
output_arg.ShortDebugString());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
} else if (!output_arg.type_attr().empty()) {
|
||||||
|
Attribute attr = attrs_[output_arg.type_attr()];
|
||||||
|
if (!attr) {
|
||||||
|
s->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Missing attribute '", output_arg.type_attr(),
|
||||||
|
"' required for output '", output_arg.name(), "'");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
|
||||||
|
if (!type_attr) {
|
||||||
|
s->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Attribute '", output_arg.type_attr(), "' required for output '",
|
||||||
|
output_arg.name(), "' isn't a type attribute");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
state_->types.push_back(type_attr.getValue());
|
||||||
|
} else if (!output_arg.type_list_attr().empty()) {
|
||||||
|
// This is pointing to an attribute which is an array of types.
|
||||||
|
Attribute attr = attrs_[output_arg.type_list_attr()];
|
||||||
|
if (!attr) {
|
||||||
|
s->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Missing attribute '", output_arg.type_list_attr(),
|
||||||
|
"' required for output '", output_arg.name(), "'");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
ArrayAttr array_attr = attr.dyn_cast<ArrayAttr>();
|
||||||
|
if (!array_attr) {
|
||||||
|
s->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Attribute '", output_arg.type_list_attr(),
|
||||||
|
"' required for output '", output_arg.name(),
|
||||||
|
"' isn't an array attribute");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
for (Attribute attr : array_attr) {
|
||||||
|
TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
|
||||||
|
if (!type_attr) {
|
||||||
|
s->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Array Attribute '", output_arg.type_list_attr(),
|
||||||
|
"' required for output '", output_arg.name(),
|
||||||
|
"' has a non-Type element");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
state_->types.push_back(type_attr.getValue());
|
||||||
|
}
|
||||||
|
} else if (output_arg.type() != tensorflow::DT_INVALID) {
|
||||||
|
Type type;
|
||||||
|
Builder builder(context_);
|
||||||
|
s->status = ConvertDataTypeToTensor(output_arg.type(), builder, &type);
|
||||||
|
if (!s->status.ok()) return nullptr;
|
||||||
|
state_->types.push_back(type);
|
||||||
|
} else {
|
||||||
|
s->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"No type fields in ", output_arg.ShortDebugString());
|
||||||
|
if (!s->status.ok()) return nullptr;
|
||||||
|
}
|
||||||
|
if (output_arg.is_ref()) {
|
||||||
|
// For all types that were added by this function call, make them refs.
|
||||||
|
for (Type& type : llvm::make_range(&state_->types[original_size],
|
||||||
|
state_->types.end())) {
|
||||||
|
type = AddRef(type, s);
|
||||||
|
if (!s->status.ok()) return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return state_.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_Function* MlirFunction::GetTfFunction(TF_Status* s) {
|
||||||
|
PassManager pm(func_.getContext());
|
||||||
|
pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
|
||||||
|
pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
|
||||||
|
|
||||||
|
// In case of failure, the `diag_handler` converts MLIR errors emitted to
|
||||||
|
// the MLIRContext into a tensorflow::Status.
|
||||||
|
StatusScopedDiagnosticHandler diag_handler(func_.getContext());
|
||||||
|
LogicalResult result = pm.run(func_.getParentOfType<ModuleOp>());
|
||||||
|
(void)result;
|
||||||
|
s->status = diag_handler.ConsumeStatus();
|
||||||
|
if (!s->status.ok()) return nullptr;
|
||||||
|
|
||||||
|
tensorflow::GraphExportConfig configs;
|
||||||
|
std::unique_ptr<TF_Function> tf_function(new TF_Function);
|
||||||
|
s->status = ConvertMlirFunctionToFunctionLibraryDef(func_, configs,
|
||||||
|
&tf_function->fdef);
|
||||||
|
return tf_function.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
void MlirFunctionContext::ExecuteOperation(AbstractOp* abstract_op,
|
||||||
|
int num_inputs,
|
||||||
|
AbstractTensor* const* inputs,
|
||||||
|
OutputList* o, TF_Status* s) {
|
||||||
|
auto* mlir_op = dyncast<MlirAbstractOp>(abstract_op);
|
||||||
|
if (mlir_op == nullptr) {
|
||||||
|
s->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Unable to cast AbstractOp to TF_GraphOp.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
SmallVector<Value, 8> operands;
|
||||||
|
for (int i = 0; i < num_inputs; ++i) {
|
||||||
|
auto* operand = dyncast<MlirTensor>(inputs[i]);
|
||||||
|
if (!operand) {
|
||||||
|
s->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Capturing eager tensors is not supported yet.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (operand->getValue().getContext() != context_.get()) {
|
||||||
|
s->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Capturing tensors from other context is not supported.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
operands.push_back(operand->getValue());
|
||||||
|
}
|
||||||
|
OperationState* state = mlir_op->Create(operands, s);
|
||||||
|
if (!s->status.ok() || !state) return;
|
||||||
|
Operation* op = builder_.createOperation(*state);
|
||||||
|
int num_results = op->getNumResults();
|
||||||
|
o->outputs.clear();
|
||||||
|
o->outputs.reserve(num_results);
|
||||||
|
for (Value result : op->getResults())
|
||||||
|
o->outputs.push_back(new MlirTensor(result));
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractTensor* MlirFunctionContext::AddParameter(TF_DataType dtype,
|
||||||
|
TF_Status* s) {
|
||||||
|
Type type;
|
||||||
|
s->status = ConvertDataTypeToTensor(static_cast<tensorflow::DataType>(dtype),
|
||||||
|
builder_, &type);
|
||||||
|
if (!s->status.ok()) return nullptr;
|
||||||
|
return new MlirTensor(func_.getBody().front().addArgument(type));
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractFunction* MlirFunctionContext::Finalize(OutputList* outputs,
|
||||||
|
TF_Status* s) {
|
||||||
|
Block& body = func_.getBody().front();
|
||||||
|
SmallVector<Value, 8> ret_operands;
|
||||||
|
for (AbstractTensor* output : outputs->outputs) {
|
||||||
|
auto* operand = dyncast<MlirTensor>(output);
|
||||||
|
if (!operand) {
|
||||||
|
s->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Capturing eager tensors is not supported yet.");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (operand->getValue().getContext() != context_.get()) {
|
||||||
|
s->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Capturing tensors from other context is not supported.");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
ret_operands.push_back(operand->getValue());
|
||||||
|
}
|
||||||
|
builder_.create<ReturnOp>(func_.getLoc(), ret_operands);
|
||||||
|
|
||||||
|
auto arg_types = llvm::to_vector<8>(body.getArgumentTypes());
|
||||||
|
auto result_types =
|
||||||
|
llvm::to_vector<8>(body.getTerminator()->getOperandTypes());
|
||||||
|
func_.setType(FunctionType::get(arg_types, result_types, func_.getContext()));
|
||||||
|
return new MlirFunction(std::move(context_), std::move(module_), func_);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
ExecutionContext* MlirTracingFactory(const char* fn_name, TF_Status* s) {
|
||||||
|
RegisterDialects();
|
||||||
|
return new MlirFunctionContext(fn_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end anonymous namespace
|
||||||
|
} // end namespace TF
|
||||||
|
} // end namespace mlir
|
|
@ -0,0 +1,31 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||||
|
|
||||||
|
using tensorflow::internal::ExecutionContext;
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
ExecutionContext* MlirTracingFactory(const char* fn_name, TF_Status* s);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Register the tracing implemented in this file as the default tracing engine.
|
||||||
|
static bool register_tracing = [] {
|
||||||
|
RegisterTracingEngineFactory("mlir", MlirTracingFactory);
|
||||||
|
return true;
|
||||||
|
}();
|
||||||
|
|
||||||
|
} // namespace
|
|
@ -782,4 +782,22 @@ StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
|
||||||
return graphdef;
|
return graphdef;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef(
|
||||||
|
mlir::FuncOp func, const GraphExportConfig& configs,
|
||||||
|
FunctionDef* function_def) {
|
||||||
|
Dialect* tf_dialect = func.getContext()->getRegisteredDialect("tf");
|
||||||
|
FunctionDefLibrary flib;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
Exporter::ConvertLibFunction(configs, tf_dialect, func, &flib));
|
||||||
|
for (auto& func_def : flib.function()) {
|
||||||
|
if (func_def.signature().name() == func.getName()) {
|
||||||
|
*function_def = func_def;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Function couldn't be found in the FunctionDefLibrary after converting "
|
||||||
|
"from MLIR");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "absl/container/flat_hash_set.h"
|
#include "absl/container/flat_hash_set.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
#include "mlir/IR/Function.h" // from @llvm-project
|
||||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||||
#include "mlir/IR/Module.h" // from @llvm-project
|
#include "mlir/IR/Module.h" // from @llvm-project
|
||||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||||
|
@ -50,6 +51,12 @@ stream_executor::port::Status ConvertMlirToGraph(
|
||||||
stream_executor::port::Status ConvertMlirToGraph(
|
stream_executor::port::Status ConvertMlirToGraph(
|
||||||
mlir::ModuleOp module, const GraphExportConfig& configs,
|
mlir::ModuleOp module, const GraphExportConfig& configs,
|
||||||
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def);
|
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def);
|
||||||
|
|
||||||
|
// Converts an MLIR function and adds it to a FunctionLibraryDefinition.
|
||||||
|
stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef(
|
||||||
|
mlir::FuncOp func, const GraphExportConfig& configs,
|
||||||
|
FunctionDef* function_def);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_EXPORT_GRAPHDEF_H_
|
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_EXPORT_GRAPHDEF_H_
|
||||||
|
|
Loading…
Reference in New Issue