Implement MLIR and graph tracing using abstract interfaces shared with immediate execution mode.
Enable the unified API test to run with TFRT. Replace dyn_cast with tensorflow::dyn_cast. TF_ExecuteOperation no longer takes a TF_ExecutionContext argument, since the current implementations tie the op builder to its creating context and it is not clear that this API will ever be supported.

PiperOrigin-RevId: 318203001
Change-Id: If5048b1404b87c809606c236419e1869630bcd46
parent 86fc04ef9b
commit 23c45d4828
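In practice, the signature change means callers of the unified C API simply drop the context argument from TF_ExecuteOperation; the op already carries a reference to the context that created it. A minimal sketch of an updated call site follows (the `ctx`, `inputs`, and `status` variables are illustrative, not part of this change):

    /* Before: TF_ExecuteOperation(add_op, 2, inputs, add_outputs, ctx, status); */
    TF_AbstractOp* add_op = TF_NewAbstractOp(ctx);
    TF_AbstractOpSetOpType(add_op, "AddV2", status);
    TF_OutputList* add_outputs = TF_NewOutputList();
    TF_OutputListSetNumOutputs(add_outputs, 1, status);  /* required in eager mode */
    TF_ExecuteOperation(add_op, 2, inputs, add_outputs, status);
    TF_DeleteAbstractOp(add_op);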
@@ -158,9 +158,13 @@ cc_library(
         "//tensorflow:internal",
     ],
     deps = [
+        ":abstract_context",
+        ":abstract_operation",
+        ":abstract_tensor_handle",
         ":c_api",
         ":c_api_experimental",
         "//tensorflow/c:c_api_internal",
+        "//tensorflow/c:conversion_macros",
         "//tensorflow/c:tf_status",
         "//tensorflow/core/platform:casts",
         "//tensorflow/core/platform:types",
@@ -541,6 +545,9 @@ tf_cuda_library(
         ":abstract_operation",
         ":abstract_context",
         ":abstract_tensor_handle",
+        ":immediate_execution_tensor_handle",
+        ":immediate_execution_context",
+        "//tensorflow/core/lib/llvm_rtti",
         "//tensorflow/c:c_api",
         "//tensorflow/c:c_api_internal",
         "//tensorflow/core:core_cpu",
@@ -559,6 +566,7 @@ tf_cuda_library(
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "@com_google_absl//absl/types:variant",
+        "//tensorflow/c:conversion_macros",
     ],
 }) + select({
     "//tensorflow:with_xla_support": [
@@ -732,6 +740,10 @@ filegroup(
     ],
     exclude = [
         "c_api_experimental.cc",
+        "c_api_unified_experimental.cc",
+        "c_api_unified_experimental_eager.cc",
+        "c_api_unified_experimental_graph.cc",
+        "c_api_unified_experimental_internal.h",
         "*test*",
         "*dlpack*",
     ],
@@ -25,7 +25,7 @@ namespace tensorflow {
 // function.
 class AbstractFunction {
  protected:
-  enum AbstractFunctionKind { kGraphFunc, kMlirFunc };
+  enum AbstractFunctionKind { kGraph, kMlir };
   explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {}

  public:
@@ -22,15 +22,17 @@ limitations under the License.
 #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
 #include "tensorflow/c/tf_datatype.h"
 #include "tensorflow/c/tf_status.h"
+#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/types.h"

 using tensorflow::string;
-using tensorflow::internal::OutputList;
-using tensorflow::internal::unwrap;

 namespace tensorflow {
-namespace internal {
-typedef absl::flat_hash_map<std::string, FactoryFunction> FactoriesMap;
+namespace tracing {
+typedef absl::flat_hash_map<std::string, tracing::FactoryFunction> FactoriesMap;

 static FactoriesMap& GetFactories() {
   static FactoriesMap* factories = new FactoriesMap;
@@ -48,8 +50,8 @@ void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {

 void SetDefaultTracingEngine(const char* name) { default_factory = name; }

-static ExecutionContext* CreateTracingExecutionContext(const char* fn_name,
+static TracingContext* CreateTracingExecutionContext(const char* fn_name,
                                                        TF_Status* s) {
   auto entry = GetFactories().find(default_factory);
   if (entry != GetFactories().end()) return entry->second(fn_name, s);
   string msg = absl::StrCat(
@@ -70,7 +72,7 @@ static ExecutionContext* CreateTracingExecutionContext(const char* fn_name,
   return nullptr;
 }

-}  // end namespace internal
+}  // end namespace tracing
 }  // end namespace tensorflow

 // =============================================================================
@@ -83,43 +85,77 @@ static ExecutionContext* CreateTracingExecutionContext(const char* fn_name,
 //
 // =============================================================================

+using tensorflow::AbstractFunction;
+using tensorflow::AbstractTensorHandle;
+using tensorflow::DataType;
+using tensorflow::dyn_cast;
+using tensorflow::OutputList;
+using tensorflow::Status;
+using tensorflow::unwrap;
+using tensorflow::wrap;
+using tensorflow::tracing::CreateTracingExecutionContext;
+using tensorflow::tracing::SetDefaultTracingEngine;
+using tensorflow::tracing::TracingContext;
+using tensorflow::tracing::TracingOperation;
+using tensorflow::tracing::TracingTensorHandle;

 void TF_SetTracingImplementation(const char* name) {
-  tensorflow::internal::SetDefaultTracingEngine(name);
+  SetDefaultTracingEngine(name);
 }

 // Creates a new TensorFlow function, it is an execution context attached to a
 // given tracing context.
 TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) {
-  return wrap(tensorflow::internal::CreateTracingExecutionContext(fn_name, s));
+  return wrap(CreateTracingExecutionContext(fn_name, s));
 }

 TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
                                          TF_OutputList* outputs, TF_Status* s) {
-  auto* func = wrap(unwrap(ctx)->Finalize(unwrap(outputs), s));
+  AbstractFunction* func;
+  TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(ctx));
+  if (!tracing_ctx) {
+    Set_TF_Status_from_Status(
+        s, tensorflow::errors::InvalidArgument(
+               "Only TracingContext can be converted into a function."));
+    return nullptr;
+  }
+  Set_TF_Status_from_Status(s, tracing_ctx->Finalize(unwrap(outputs), &func));
   TF_DeleteExecutionContext(ctx);
-  return func;
+  return wrap(func);
 }

 TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
                                            TF_DataType dtype, TF_Status* s) {
-  return wrap(unwrap(func)->AddParameter(dtype, s));
+  TracingTensorHandle* t;
+  TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(func));
+  if (!tracing_ctx) {
+    Set_TF_Status_from_Status(
+        s, tensorflow::errors::InvalidArgument(
+               "TF_AddFunctionParameter must be called on a TracingContext."));
+    return nullptr;
+  }
+  Set_TF_Status_from_Status(
+      s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), &t));
+  return wrap(t);
 }

-void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }
+void TF_DeleteExecutionContext(TF_ExecutionContext* c) { unwrap(c)->Release(); }

 TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
-  return wrap(unwrap(c)->CreateOperation());
+  return wrap((unwrap(c)->CreateOperation()));
 }

-void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete unwrap(op); }
+void TF_DeleteAbstractOp(TF_AbstractOp* op) { unwrap(op)->Release(); }

-void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete unwrap(t); }
+void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Release(); }

 TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); }
 void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); }
 void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
                                 TF_Status* s) {
   unwrap(o)->expected_num_outputs = num_outputs;
+  unwrap(o)->outputs.clear();
+  unwrap(o)->outputs.resize(num_outputs);
 }
 int TF_OutputListNumOutputs(TF_OutputList* o) {
   return unwrap(o)->outputs.size();
@@ -134,24 +170,46 @@ void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,

 void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
                             TF_Status* s) {
-  unwrap(op)->SetOpType(op_type, s);
+  Set_TF_Status_from_Status(s, unwrap(op)->Reset(op_type,
+                                                 /*raw_device_name=*/nullptr));
 }

 void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
                             TF_Status* s) {
-  unwrap(op)->SetOpName(op_name, s);
+  TracingOperation* tracing_op = dyn_cast<TracingOperation>(unwrap(op));
+  if (!tracing_op) {
+    Set_TF_Status_from_Status(
+        s, tensorflow::errors::InvalidArgument(
+               "TF_AbstractOpSetOpName must be called on a TracingOperation."));
+    return;
+  }
+  Set_TF_Status_from_Status(s, tracing_op->SetOpName(op_name));
 }

 void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
                               TF_DataType value, TF_Status* s) {
-  unwrap(op)->SetAttrType(attr_name, value, s);
+  Status status =
+      unwrap(op)->SetAttrType(attr_name, static_cast<DataType>(value));
+  TF_SetStatus(s, static_cast<TF_Code>(status.code()),
+               status.error_message().c_str());
 }

 void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
                          TF_AbstractTensor* const* inputs, TF_OutputList* o,
-                         TF_ExecutionContext* ctx, TF_Status* s) {
-  unwrap(ctx)->ExecuteOperation(unwrap(op), num_inputs, &unwrap(*inputs),
-                                unwrap(o), s);
+                         TF_Status* s) {
+  for (int i = 0; i < num_inputs; i++) {
+    Set_TF_Status_from_Status(s, unwrap(op)->AddInput(unwrap(inputs[i])));
+    if (TF_GetCode(s) != TF_OK) {
+      return;
+    }
+  }
+  int num_outputs = unwrap(o)->expected_num_outputs;
+  Set_TF_Status_from_Status(
+      s, unwrap(op)->Execute(
+             absl::MakeSpan(reinterpret_cast<AbstractTensorHandle**>(
+                                unwrap(o)->outputs.data()),
+                            unwrap(o)->outputs.size()),
+             &num_outputs));
 }

 void TF_DeleteAbstractFunction(TF_AbstractFunction* func) {
@@ -161,5 +219,5 @@ void TF_DeleteAbstractFunction(TF_AbstractFunction* func) {
 void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx,
                                          TF_AbstractFunction* func,
                                          TF_Status* s) {
-  unwrap(ctx)->RegisterFunction(unwrap(func), s);
+  Set_TF_Status_from_Status(s, unwrap(ctx)->RegisterFunction(unwrap(func)));
 }
@@ -110,7 +110,7 @@ void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
 // Any active tape will observe the effects of this execution.
 void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
                          TF_AbstractTensor* const* inputs, TF_OutputList* o,
-                         TF_ExecutionContext* ctx, TF_Status* s);
+                         TF_Status* s);

 // Creates a new TF_AbstractFunction from the current tracing states in the
 // context. The provided `ctx` is consumed by this API call and deleted.
@@ -137,7 +137,8 @@ TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
                                                           TF_Status* s);
 TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
                                                   TF_Status* s);
-TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*);
+TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*,
+                                              TF_Status* s);

 #ifdef __cplusplus
 } /* end extern "C" */
@@ -15,180 +15,68 @@ limitations under the License.

 #include <vector>

+#include "tensorflow/c/eager/abstract_context.h"
+#include "tensorflow/c/eager/abstract_tensor_handle.h"
 #include "tensorflow/c/eager/c_api.h"
 #include "tensorflow/c/eager/c_api_unified_experimental.h"
 #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
-#include "tensorflow/c/tf_datatype.h"
+#include "tensorflow/c/eager/immediate_execution_context.h"
+#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
+#include "tensorflow/c/eager/tfe_context_internal.h"
+#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
 #include "tensorflow/c/tf_status.h"
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
-#include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 #include "tensorflow/core/platform/strcat.h"
-#include "tensorflow/core/platform/types.h"
-
-using tensorflow::string;
-
-namespace tensorflow {
-namespace internal {
-
-// Simple wrapper over a TFE_TensorHandle
-struct EagerTensor : public AbstractTensor {
-  TFE_TensorHandle* t = nullptr;
-  EagerTensor() : AbstractTensor(kKind) {}
-  explicit EagerTensor(TFE_TensorHandle* t) : AbstractTensor(kKind), t(t) {}
-  ~EagerTensor() override { TFE_DeleteTensorHandle(t); }
-  static constexpr AbstractTensorKind kKind = kEagerTensor;
-};
-
-// Simple wrapper over a TFE_Op
-class EagerOp : public AbstractOp {
- public:
-  explicit EagerOp(TFE_Context* ctx) : AbstractOp(kKind), ctx_(ctx) {}
-  void SetOpType(const char* const op_type, TF_Status* s) override {
-    op_ = TFE_NewOp(ctx_, op_type, s);
-  }
-  void SetOpName(const char* const op_name, TF_Status* s) override {
-    // Name is ignored in eager mode.
-  }
-  void SetAttrType(const char* const attr_name, TF_DataType value,
-                   TF_Status* s) override {
-    if (op_ == nullptr) {
-      TF_SetStatus(s, TF_FAILED_PRECONDITION,
-                   "op_type must be specified before specifying attrs.");
-      return;
-    }
-    TFE_OpSetAttrType(op_, attr_name, value);
-  }
-
-  ~EagerOp() override { TFE_DeleteOp(op_); }
-  static constexpr AbstractOpKind kKind = kEagerOp;
-
- private:
-  friend class EagerContext;  // For access to op_.
-  TFE_Op* op_ = nullptr;
-  TFE_Context* ctx_;
-};
-
-// Wraps a TFE_Context and dispatch EagerOp with EagerTensor inputs.
-class EagerContext : public ExecutionContext {
- public:
-  EagerContext() : ExecutionContext(kKind) {}
-
-  void Build(TFE_ContextOptions* options, TF_Status* status) {
-    eager_ctx_ = TFE_NewContext(options, status);
-  }
-
-  AbstractOp* CreateOperation() override {
-    // TODO(srbs): Should the lifetime of this op be tied to the context.
-    return new EagerOp(eager_ctx_);
-  }
-
-  void ExecuteOperation(AbstractOp* op, int num_inputs,
-                        AbstractTensor* const* inputs, OutputList* o,
-                        TF_Status* s) override {
-    auto* eager_op = dyncast<EagerOp>(op);
-    if (eager_op == nullptr) {
-      TF_SetStatus(s, TF_INVALID_ARGUMENT,
-                   "Unable to cast AbstractOp to TF_EagerOp.");
-      return;
-    }
-    auto* tfe_op = eager_op->op_;
-    if (TF_GetCode(s) != TF_OK) return;
-    for (int i = 0; i < num_inputs; ++i) {
-      auto* eager_tensor = dyncast<const EagerTensor>(inputs[i]);
-      if (!eager_tensor) {
-        TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
-        return;
-      }
-      TFE_OpAddInput(tfe_op, eager_tensor->t, s);
-      if (TF_GetCode(s) != TF_OK) return;
-    }
-    if (o->expected_num_outputs == -1) {
-      string msg =
-          "The number of outputs must be provided in eager mode. Use "
-          "TF_OutputListSetNumOutputs.";
-      TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
-      return;
-    }
-    tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
-    int num_retvals = o->expected_num_outputs;
-    retvals.resize(num_retvals);
-    TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
-    if (TF_GetCode(s) != TF_OK) {
-      return;
-    }
-    o->outputs.clear();
-    o->outputs.reserve(num_retvals);
-    for (int i = 0; i < num_retvals; ++i) {
-      o->outputs.push_back(new EagerTensor(retvals[i]));
-    }
-  }
-
-  AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
-    TF_SetStatus(s, TF_INVALID_ARGUMENT,
-                 "Can't add function parameter on an eager context.");
-    return nullptr;
-  }
-  AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
-    TF_SetStatus(s, TF_INVALID_ARGUMENT,
-                 "Can't use finalize function on an eager context.");
-    return nullptr;
-  }
-
-  void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override {
-    auto* func = afunc->GetTfFunction(s);
-    if (!func) {
-      return;
-    }
-    TFE_ContextAddFunction(eager_ctx_, func, s);
-  }
-
-  ~EagerContext() override { TFE_DeleteContext(eager_ctx_); }
-
-  static constexpr ExecutionContextKind kKind = kEagerContext;
-
- private:
-  friend TFE_Context* ::TF_ExecutionContextGetTFEContext(
-      TF_ExecutionContext* ctx);
-  TFE_Context* eager_ctx_;
-};
-
-}  // namespace internal
-}  // namespace tensorflow

 // =============================================================================
 // Public C API entry points
 // These are only the entry points specific to the Eager API.
 // =============================================================================

-using tensorflow::internal::dyncast;
-using tensorflow::internal::unwrap;
+using tensorflow::AbstractContext;
+using tensorflow::AbstractTensorHandle;
+using tensorflow::dyn_cast;
+using tensorflow::ImmediateExecutionContext;
+using tensorflow::ImmediateExecutionTensorHandle;
+using tensorflow::string;
+using tensorflow::unwrap;
+using tensorflow::wrap;
+using tensorflow::strings::StrCat;

 TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions* options,
                                                  TF_Status* s) {
-  auto* ctx = new tensorflow::internal::EagerContext();
-  ctx->Build(options, s);
-  return wrap(ctx);
+  TFE_Context* c_ctx = TFE_NewContext(options, s);
+  if (TF_GetCode(s) != TF_OK) {
+    return nullptr;
+  }
+  return wrap(static_cast<AbstractContext*>(unwrap(c_ctx)));
 }

 TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
                                                           TF_Status* s) {
-  return wrap(new tensorflow::internal::EagerTensor(t));
+  return wrap(static_cast<AbstractTensorHandle*>(unwrap(t)));
 }

 TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
                                                   TF_Status* s) {
-  auto* eager_tensor = dyncast<tensorflow::internal::EagerTensor>(unwrap(at));
-  if (!eager_tensor) {
-    string msg = tensorflow::strings::StrCat("Not an eager tensor handle.",
-                                             reinterpret_cast<uintptr_t>(at));
+  auto handle = dyn_cast<ImmediateExecutionTensorHandle>(unwrap(at));
+  if (!handle) {
+    string msg =
+        StrCat("Not an eager tensor handle.", reinterpret_cast<uintptr_t>(at));
     TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
     return nullptr;
   }
-  return eager_tensor->t;
+  return wrap(handle);
 }

-TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) {
-  auto* eager_ctx = dyncast<tensorflow::internal::EagerContext>(unwrap(ctx));
-  if (!eager_ctx) return nullptr;
-  return eager_ctx->eager_ctx_;
+TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx,
+                                              TF_Status* s) {
+  auto imm_ctx = dyn_cast<ImmediateExecutionContext>(unwrap(ctx));
+  if (!imm_ctx) {
+    string msg =
+        StrCat("Not an eager context.", reinterpret_cast<uintptr_t>(ctx));
+    TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
+    return nullptr;
+  }
+  return wrap(imm_ctx);
 }
@ -18,77 +18,198 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/abstract_context.h"
|
||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||||
#include "tensorflow/c/tf_datatype.h"
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/strcat.h"
|
#include "tensorflow/core/platform/strcat.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
using tensorflow::dyn_cast;
|
||||||
using tensorflow::string;
|
using tensorflow::string;
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace internal {
|
namespace tracing {
|
||||||
|
namespace graph {
|
||||||
|
|
||||||
class GraphContext;
|
class GraphContext;
|
||||||
|
class GraphOperation;
|
||||||
|
class GraphTensor;
|
||||||
|
|
||||||
// GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
|
// GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
|
||||||
// into the list of outputs for the operation.
|
// into the list of outputs for the operation.
|
||||||
struct GraphTensor : public AbstractTensor {
|
class GraphTensor : public TracingTensorHandle {
|
||||||
TF_Output output{};
|
public:
|
||||||
GraphContext* ctx = nullptr;
|
explicit GraphTensor(TF_Output output)
|
||||||
GraphTensor() : AbstractTensor(kKind) {}
|
: TracingTensorHandle(kGraph), output_(output) {}
|
||||||
GraphTensor(TF_Output output, GraphContext* ctx)
|
void Release() override { delete this; }
|
||||||
: AbstractTensor(kKind), output(output), ctx(ctx) {}
|
TF_Output output_;
|
||||||
static constexpr AbstractTensorKind kKind = kGraphTensor;
|
|
||||||
|
// For LLVM style RTTI.
|
||||||
|
static bool classof(const AbstractTensorHandle* ptr) {
|
||||||
|
return ptr->getKind() == kGraph;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// GraphOp wraps and populate a TF_OperationDescription.
|
// GraphOperation wraps and populates a TF_OperationDescription.
|
||||||
class GraphOp : public AbstractOp {
|
class GraphOperation : public TracingOperation {
|
||||||
public:
|
public:
|
||||||
explicit GraphOp(TF_Graph* g) : AbstractOp(kKind), g_(g) {}
|
explicit GraphOperation(TF_Graph* g) : TracingOperation(kGraph), g_(g) {}
|
||||||
void SetOpType(const char* const op_type, TF_Status* s) override {
|
void Release() override { delete this; }
|
||||||
|
Status Reset(const char* op, const char* raw_device_name) override {
|
||||||
if (op_) {
|
if (op_) {
|
||||||
TF_SetStatus(
|
return errors::FailedPrecondition("Reset called on already built op.");
|
||||||
s, TF_FAILED_PRECONDITION,
|
|
||||||
strings::StrCat("SetOpType called on already built op.").c_str());
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
if (op_name_ != nullptr) {
|
if (raw_device_name) {
|
||||||
op_.reset(TF_NewOperation(g_, op_type, op_name_));
|
device_name_ = raw_device_name;
|
||||||
op_name_ = nullptr;
|
|
||||||
} else {
|
|
||||||
op_type_ = op_type;
|
|
||||||
}
|
}
|
||||||
|
op_type_ = op;
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
void SetOpName(const char* const op_name, TF_Status* s) override {
|
Status SetOpName(const char* const op_name) override {
|
||||||
if (op_) {
|
if (op_) {
|
||||||
TF_SetStatus(
|
return errors::FailedPrecondition(
|
||||||
s, TF_FAILED_PRECONDITION,
|
"SetOpName called on already built op.");
|
||||||
strings::StrCat("SetOpName called on already built op.").c_str());
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
if (op_type_ != nullptr) {
|
if (op_type_.empty()) {
|
||||||
op_.reset(TF_NewOperation(g_, op_type_, op_name));
|
return errors::FailedPrecondition(
|
||||||
op_type_ = nullptr;
|
"GraphOperation::Reset must be called before calling SetOpName.");
|
||||||
} else {
|
|
||||||
op_name_ = op_name;
|
|
||||||
}
|
}
|
||||||
|
op_.reset(TF_NewOperation(g_, op_type_.c_str(), op_name));
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
void SetAttrType(const char* const attr_name, TF_DataType value,
|
const string& Name() const override { return op_type_; }
|
||||||
TF_Status* s) override {
|
const string& DeviceName() const override { return device_name_; }
|
||||||
if (!op_) {
|
|
||||||
TF_SetStatus(
|
|
||||||
s, TF_FAILED_PRECONDITION,
|
|
||||||
"op_type and op_name must be specified before specifying attrs.");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
TF_SetAttrType(op_.get(), attr_name, value);
|
|
||||||
}
|
|
||||||
~GraphOp() override {}
|
|
||||||
|
|
||||||
static constexpr AbstractOpKind kKind = kGraphOp;
|
Status SetDeviceName(const char* name) override {
|
||||||
|
// TODO(srbs): Implement this.
|
||||||
|
device_name_ = name;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status AddInput(AbstractTensorHandle* input) override {
|
||||||
|
GraphTensor* t = dyn_cast<GraphTensor>(input);
|
||||||
|
if (!t) {
|
||||||
|
return tensorflow::errors::InvalidArgument(
|
||||||
|
"Unable to cast input to GraphTensor");
|
||||||
|
}
|
||||||
|
TF_AddInput(op_.get(), t->output_);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) override {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"AddInputList has not been implemented yet.");
|
||||||
|
}
|
||||||
|
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||||
|
int* num_retvals) override {
|
||||||
|
auto* tf_opdesc = op_.release();
|
||||||
|
if (tf_opdesc == nullptr) {
|
||||||
|
return errors::InvalidArgument("AbstractOp is incomplete.");
|
||||||
|
}
|
||||||
|
TF_Status* s = TF_NewStatus();
|
||||||
|
auto* operation = TF_FinishOperation(tf_opdesc, s);
|
||||||
|
TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
|
||||||
|
TF_DeleteStatus(s);
|
||||||
|
*num_retvals = TF_OperationNumOutputs(operation);
|
||||||
|
for (int i = 0; i < *num_retvals; ++i) {
|
||||||
|
retvals[i] = new GraphTensor({operation, i});
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status SetAttrString(const char* attr_name, const char* data,
|
||||||
|
size_t length) override {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"SetAttrString has not been implemented yet.");
|
||||||
|
}
|
||||||
|
Status SetAttrInt(const char* attr_name, int64_t value) override {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"SetAttrInt has not been implemented yet.");
|
||||||
|
}
|
||||||
|
Status SetAttrFloat(const char* attr_name, float value) override {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"SetAttrFloat has not been implemented yet.");
|
||||||
|
}
|
||||||
|
Status SetAttrBool(const char* attr_name, bool value) override {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"SetAttrBool has not been implemented yet.");
|
||||||
|
}
|
||||||
|
Status SetAttrType(const char* const attr_name, DataType value) override {
|
||||||
|
if (!op_) {
|
||||||
|
return Status(
|
||||||
|
error::Code::FAILED_PRECONDITION,
|
||||||
|
"op_type and op_name must be specified before specifying attrs.");
|
||||||
|
}
|
||||||
|
op_->node_builder.Attr(attr_name, value);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
Status SetAttrShape(const char* attr_name, const int64_t* dims,
|
||||||
|
const int num_dims) override {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"SetAttrShape has not been implemented yet.");
|
||||||
|
}
|
||||||
|
Status SetAttrFunction(const char* attr_name,
|
||||||
|
const AbstractOperation* value) override {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"SetAttrFunction has not been implemented yet.");
|
||||||
|
}
|
||||||
|
Status SetAttrFunctionName(const char* attr_name, const char* value,
|
||||||
|
size_t length) override {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"SetAttrFunctionName has not been implemented yet.");
|
||||||
|
}
|
||||||
|
Status SetAttrTensor(const char* attr_name,
|
||||||
|
AbstractTensorInterface* tensor) override {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"SetAttrTensor has not been implemented yet.");
|
||||||
|
}
|
||||||
|
Status SetAttrStringList(const char* attr_name, const void* const* values,
|
||||||
|
const size_t* lengths, int num_values) override {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"SetAttrStringList has not been implemented yet.");
|
||||||
|
}
|
||||||
|
Status SetAttrFloatList(const char* attr_name, const float* values,
|
||||||
|
int num_values) override {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"SetAttrFloatList has not been implemented yet.");
|
||||||
|
}
|
||||||
|
Status SetAttrIntList(const char* attr_name, const int64_t* values,
|
||||||
|
int num_values) override {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"SetAttrIntList has not been implemented yet.");
|
||||||
|
}
|
||||||
|
Status SetAttrTypeList(const char* attr_name, const DataType* values,
|
||||||
|
int num_values) override {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"SetAttrTypeList has not been implemented yet.");
|
||||||
|
}
|
||||||
|
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
|
||||||
|
int num_values) override {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"SetAttrBoolList has not been implemented yet.");
|
||||||
|
}
|
||||||
|
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
|
||||||
|
const int* num_dims, int num_values) override {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"SetAttrShapeList has not been implemented yet.");
|
||||||
|
}
|
||||||
|
Status SetAttrFunctionList(
|
||||||
|
const char* attr_name,
|
||||||
|
absl::Span<const AbstractOperation*> values) override {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"SetAttrFunctionList has not been implemented yet.");
|
||||||
|
}
|
||||||
|
// For LLVM style RTTI.
|
||||||
|
static bool classof(const AbstractOperation* ptr) {
|
||||||
|
return ptr->getKind() == kGraph;
|
||||||
|
}
|
||||||
|
~GraphOperation() override {}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class GraphContext; // For access to op_.
|
friend class GraphContext; // For access to op_.
|
||||||
@ -96,123 +217,109 @@ class GraphOp : public AbstractOp {
|
|||||||
std::unique_ptr<TF_OperationDescription> op_;
|
std::unique_ptr<TF_OperationDescription> op_;
|
||||||
// Hold `op_type` and `op_name` till both are available since we need both
|
// Hold `op_type` and `op_name` till both are available since we need both
|
||||||
// to build a graph operation.
|
// to build a graph operation.
|
||||||
const char* op_type_ = nullptr;
|
string op_type_;
|
||||||
const char* op_name_ = nullptr;
|
const char* op_name_ = nullptr;
|
||||||
|
// TODO(srbs): Use this.
|
||||||
|
string device_name_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// GraphFunction is a thin wrapper over a TF_Function.
|
// GraphFunction is a thin wrapper over a TF_Function.
|
||||||
struct GraphFunction : public AbstractFunction {
|
struct GraphFunction : public AbstractFunction {
|
||||||
TF_Function* func = nullptr;
|
TF_Function* func = nullptr;
|
||||||
GraphFunction() : AbstractFunction(kKind) {}
|
GraphFunction() : AbstractFunction(kGraph) {}
|
||||||
explicit GraphFunction(TF_Function* func)
|
explicit GraphFunction(TF_Function* func)
|
||||||
: AbstractFunction(kKind), func(func) {}
|
: AbstractFunction(kGraph), func(func) {}
|
||||||
~GraphFunction() override {
|
~GraphFunction() override {
|
||||||
if (func) TF_DeleteFunction(func);
|
if (func) TF_DeleteFunction(func);
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_Function* GetTfFunction(TF_Status* s) override { return func; }
|
Status GetFunctionDef(FunctionDef** fdef) override {
|
||||||
|
*fdef = &func->fdef;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
static constexpr AbstractFunctionKind kKind = kGraphFunc;
|
// For LLVM style RTTI.
|
||||||
|
static bool classof(const AbstractFunction* ptr) {
|
||||||
|
return ptr->getKind() == kGraph;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// GraphContext wraps a TF_Graph modeling a single function and manages the
|
// GraphContext wraps a TF_Graph modeling a single function and manages the
|
||||||
// "execution" of operation, i.e. adding them to the function.
|
// "execution" of operation, i.e. adding them to the function.
|
||||||
class GraphContext : public ExecutionContext {
|
class GraphContext : public TracingContext {
|
||||||
public:
|
public:
|
||||||
explicit GraphContext(const char* name)
|
explicit GraphContext(const char* name)
|
||||||
: ExecutionContext(kKind),
|
: TracingContext(kGraph),
|
||||||
graph_(new TF_Graph(), TF_DeleteGraph),
|
graph_(new TF_Graph(), TF_DeleteGraph),
|
||||||
name_(name) {}
|
name_(name) {}
|
||||||
|
|
||||||
AbstractOp* CreateOperation() override {
|
void Release() override { delete this; }
|
||||||
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
|
||||||
return new GraphOp(graph_.get());
|
TracingOperation* CreateOperation() override {
|
||||||
|
return new GraphOperation(graph_.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
void ExecuteOperation(AbstractOp* op, int num_inputs,
|
Status AddParameter(DataType dtype, TracingTensorHandle** output) override {
|
||||||
AbstractTensor* const* inputs, OutputList* o,
|
auto operation = CreateOperation();
|
||||||
TF_Status* s) override {
|
TF_RETURN_IF_ERROR(operation->Reset("Placeholder", nullptr));
|
||||||
auto* graph_op = dyncast<GraphOp>(op);
|
TF_RETURN_IF_ERROR(
|
||||||
if (graph_op == nullptr) {
|
operation->SetOpName(absl::StrCat("_input_", inputs_.size()).c_str()));
|
||||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
TF_RETURN_IF_ERROR(operation->SetAttrType("dtype", dtype));
|
||||||
"Unable to cast AbstractOp to TF_GraphOp.");
|
int num_outputs = 1;
|
||||||
return;
|
std::vector<AbstractTensorHandle*> outputs(num_outputs);
|
||||||
|
TF_RETURN_IF_ERROR(operation->Execute(
|
||||||
|
absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
|
||||||
|
|
||||||
|
if (num_outputs != 1) {
|
||||||
|
return errors::Internal("Expected 1 output but found ", num_outputs);
|
||||||
}
|
}
|
||||||
auto* tf_opdesc = graph_op->op_.release();
|
auto* t = dyn_cast<GraphTensor>(outputs[0]);
|
||||||
if (tf_opdesc == nullptr) {
|
if (!t) {
|
||||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, "AbstractOp is incomplete.");
|
return tensorflow::errors::InvalidArgument(
|
||||||
return;
|
"Unable to cast input to GraphTensor");
|
||||||
}
|
|
||||||
for (int i = 0; i < num_inputs; ++i) {
|
|
||||||
auto* graph_tensor = dyncast<GraphTensor>(inputs[i]);
|
|
||||||
if (!graph_tensor) {
|
|
||||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
|
||||||
"Capturing eager tensors is not supported yet.");
|
|
||||||
return;
|
|
||||||
} else {
|
|
||||||
if (graph_tensor->ctx != this) {
|
|
||||||
TF_SetStatus(
|
|
||||||
s, TF_INVALID_ARGUMENT,
|
|
||||||
"Capturing tensors from other graphs is not supported yet.");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
TF_AddInput(tf_opdesc, graph_tensor->output);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto* operation = TF_FinishOperation(tf_opdesc, s);
|
|
||||||
// TF_FinishOperation deletes `tf_opdesc` so clear its reference.
|
|
||||||
graph_op->op_ = nullptr;
|
|
||||||
if (TF_GetCode(s) != TF_OK) return;
|
|
||||||
int num_outputs = TF_OperationNumOutputs(operation);
|
|
||||||
o->outputs.clear();
|
|
||||||
o->outputs.reserve(num_outputs);
|
|
||||||
for (int i = 0; i < num_outputs; ++i) {
|
|
||||||
o->outputs.push_back(new GraphTensor({operation, i}, this));
|
|
||||||
}
|
}
|
||||||
|
inputs_.push_back(t->output_);
|
||||||
|
*output = tensorflow::down_cast<TracingTensorHandle*>(outputs[0]);
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
|
Status Finalize(OutputList* outputs, AbstractFunction** f) override {
|
||||||
TF_OperationDescription* opdesc =
|
|
||||||
TF_NewOperation(graph_.get(), "Placeholder",
|
|
||||||
absl::StrCat("_input_", inputs_.size()).c_str());
|
|
||||||
TF_SetAttrType(opdesc, "dtype", dtype);
|
|
||||||
auto* operation = TF_FinishOperation(opdesc, s);
|
|
||||||
if (!s->status.ok()) return nullptr;
|
|
||||||
|
|
||||||
inputs_.push_back(TF_Output{operation, 0});
|
|
||||||
return new GraphTensor(inputs_.back(), this);
|
|
||||||
}
|
|
||||||
|
|
||||||
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
|
|
||||||
std::unique_ptr<GraphFunction> func(new GraphFunction);
|
std::unique_ptr<GraphFunction> func(new GraphFunction);
|
||||||
std::vector<TF_Output> graph_outputs;
|
std::vector<TF_Output> graph_outputs;
|
||||||
graph_outputs.reserve(outputs->outputs.size());
|
graph_outputs.reserve(outputs->outputs.size());
|
||||||
for (AbstractTensor* abstract_output : outputs->outputs) {
|
for (auto* abstract_output : outputs->outputs) {
|
||||||
GraphTensor* output = dyncast<GraphTensor>(abstract_output);
|
GraphTensor* output = dyn_cast<GraphTensor>(abstract_output);
|
||||||
if (!output) {
|
if (!output) {
|
||||||
TF_SetStatus(s, TF_UNIMPLEMENTED,
|
return errors::Unimplemented(
|
||||||
"Returning a non-graph tensor from a function has not "
|
"Returning a non-graph tensor from a function has not "
|
||||||
"been implemented yet.");
|
"been implemented yet.");
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
graph_outputs.push_back(output->output);
|
graph_outputs.push_back(output->output_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto s = TF_NewStatus();
|
||||||
func->func = TF_GraphToFunction(
|
func->func = TF_GraphToFunction(
|
||||||
graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(),
|
graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(),
|
||||||
graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s);
|
graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s);
|
||||||
if (TF_GetCode(s) != TF_OK) return nullptr;
|
TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
|
||||||
return func.release();
|
TF_DeleteStatus(s);
|
||||||
|
*f = func.release();
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
|
Status RegisterFunction(AbstractFunction* func) override {
|
||||||
TF_SetStatus(s, TF_UNIMPLEMENTED,
|
return errors::Unimplemented(
|
||||||
"Registering graph functions has not been implemented yet.");
|
"Registering graph functions has not been implemented yet.");
|
||||||
}
|
}
|
||||||
|
|
||||||
~GraphContext() override {}
|
Status RemoveFunction(const string& func) override {
|
||||||
|
return errors::Unimplemented(
|
||||||
static constexpr ExecutionContextKind kKind = kGraphContext;
|
"GraphContext::RemoveFunction has not been implemented yet.");
|
||||||
|
}
|
||||||
|
// For LLVM style RTTI.
|
||||||
|
static bool classof(const AbstractContext* ptr) {
|
||||||
|
return ptr->getKind() == kGraph;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
|
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
|
||||||
@ -220,7 +327,7 @@ class GraphContext : public ExecutionContext {
|
|||||||
const char* name_;
|
const char* name_;
|
||||||
};
|
};
|
||||||
|
|
||||||
static ExecutionContext* GraphTracingFactory(const char* name, TF_Status* s) {
|
static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
|
||||||
return new GraphContext(name);
|
return new GraphContext(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -231,5 +338,6 @@ static bool register_tracing = [] {
|
|||||||
return true;
|
return true;
|
||||||
}();
|
}();
|
||||||
|
|
||||||
} // namespace internal
|
} // namespace graph
|
||||||
|
} // namespace tracing
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -19,6 +19,10 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/c/conversion_macros.h"
|
||||||
|
#include "tensorflow/c/eager/abstract_context.h"
|
||||||
|
#include "tensorflow/c/eager/abstract_operation.h"
|
||||||
|
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||||
#include "tensorflow/c/tf_datatype.h"
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
@ -26,7 +30,14 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace internal {
|
|
||||||
|
// Represents the results of the execution of an operation.
|
||||||
|
struct OutputList {
|
||||||
|
std::vector<AbstractTensorHandle*> outputs;
|
||||||
|
int expected_num_outputs = -1;
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace tracing {
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// Implementation detail for the unified execution APIs for Eager and tracing
|
// Implementation detail for the unified execution APIs for Eager and tracing
|
||||||
@ -37,165 +48,75 @@ namespace internal {
|
|||||||
// `c_api_unified_experimental.h` header.
|
// `c_api_unified_experimental.h` header.
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
// We can't depend on C++ rtti, but we still want to be able to have a safe
|
// Represents either a MlirTensor or a GraphTensor.
|
||||||
// dynamic_cast to provide diagnostics to the user when the API is misused.
|
|
||||||
// Instead we model RTTI by listing all the possible subclasses for each
|
|
||||||
// abstract base. Each subclass initializes the base class with the right
|
|
||||||
// `kind`, which allows an equivalent to `std::dynamic_cast` provided by this
|
|
||||||
// utility.
|
|
||||||
template <typename T, typename S>
|
|
||||||
T* dyncast(S source) {
|
|
||||||
if (source->getKind() != T::kKind) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return tensorflow::down_cast<T*>(source);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Represents either an EagerTensor or a GraphTensor.
|
|
||||||
// This base class does not expose any public methods other than to distinguish
|
// This base class does not expose any public methods other than to distinguish
|
||||||
// which subclass it actually is. The user is responsible to use the right
|
// which subclass it actually is. The user is responsible to use the right
|
||||||
// type of AbstractTensor in their context (do not pass an EagerTensor to a
|
// type of AbstractTensor in their context (do not pass an MlirTensor to a
|
||||||
// GraphContext and vice-versa).
|
// GraphContext and vice-versa).
|
||||||
class AbstractTensor {
|
class TracingTensorHandle : public AbstractTensorHandle {
|
||||||
protected:
|
protected:
|
||||||
enum AbstractTensorKind { kMlirTensor, kGraphTensor, kEagerTensor };
|
explicit TracingTensorHandle(AbstractTensorHandleKind kind)
|
||||||
explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
|
: AbstractTensorHandle(kind) {}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// Returns which subclass is this instance of.
|
// For LLVM style RTTI.
|
||||||
AbstractTensorKind getKind() const { return kind_; }
|
static bool classof(const AbstractTensorHandle* ptr) {
|
||||||
virtual ~AbstractTensor() = default;
|
return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
|
||||||
|
}
|
||||||
private:
|
|
||||||
const AbstractTensorKind kind_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Represents the results of the execution of an operation.
|
|
||||||
struct OutputList {
|
|
||||||
std::vector<AbstractTensor*> outputs;
|
|
||||||
int expected_num_outputs = -1;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Holds the result of tracing a function.
|
|
||||||
class AbstractFunction {
|
|
||||||
protected:
|
|
||||||
enum AbstractFunctionKind { kGraphFunc };
|
|
||||||
explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {}
|
|
||||||
|
|
||||||
public:
|
|
||||||
// Returns which subclass is this instance of.
|
|
||||||
AbstractFunctionKind getKind() const { return kind_; }
|
|
||||||
virtual ~AbstractFunction() = default;
|
|
||||||
|
|
||||||
// Temporary API till we figure the right abstraction for AbstractFunction.
|
|
||||||
// At the moment both Eager and Graph needs access to a "TF_Function" object.
|
|
||||||
virtual TF_Function* GetTfFunction(TF_Status* s) = 0;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const AbstractFunctionKind kind_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// An abstract operation describes an operation by its type, name, and
|
// An abstract operation describes an operation by its type, name, and
|
||||||
// attributes. It can be "executed" by the context with some input tensors.
|
// attributes. It can be "executed" by the context with some input tensors.
|
||||||
// It is allowed to reusing the same abstract operation for multiple execution
|
// It is allowed to reusing the same abstract operation for multiple execution
|
||||||
// on a given context, with the same or different input tensors.
|
// on a given context, with the same or different input tensors.
|
||||||
class AbstractOp {
|
class TracingOperation : public AbstractOperation {
|
||||||
protected:
|
protected:
|
||||||
enum AbstractOpKind { kMlirOp, kGraphOp, kEagerOp };
|
explicit TracingOperation(AbstractOperationKind kind)
|
||||||
explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
|
: AbstractOperation(kind) {}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// Returns which subclass is this instance of.
|
|
||||||
AbstractOpKind getKind() const { return kind_; }
|
|
||||||
virtual ~AbstractOp() = default;
|
|
||||||
|
|
||||||
// Sets the type of the operation (for example `AddV2`).
|
|
||||||
virtual void SetOpType(const char* op_type, TF_Status* s) = 0;
|
|
||||||
|
|
||||||
// Sets the name of the operation: this is an optional identifier that is
|
// Sets the name of the operation: this is an optional identifier that is
|
||||||
// not intended to carry semantics and preserved/propagated without
|
// not intended to carry semantics and preserved/propagated without
|
||||||
// guarantees.
|
// guarantees.
|
||||||
-  virtual void SetOpName(const char* op_name, TF_Status* s) = 0;
+  virtual Status SetOpName(const char* op_name) = 0;

-  // Add a `TypeAttribute` on the operation.
-  virtual void SetAttrType(const char* attr_name, TF_DataType value,
-                           TF_Status* s) = 0;
-
- private:
-  const AbstractOpKind kind_;
+  // For LLVM style RTTI.
+  static bool classof(const AbstractOperation* ptr) {
+    return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
+  }
 };

 // This holds the context for the execution: dispatching operations either to an
-// eager implementation or to a graph implementation.
-struct ExecutionContext {
+// MLIR implementation or to a graph implementation.
+class TracingContext : public AbstractContext {
  protected:
-  enum ExecutionContextKind { kMlirContext, kGraphContext, kEagerContext };
-  explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
+  explicit TracingContext(AbstractContextKind kind) : AbstractContext(kind) {}

  public:
-  // Returns which subclass is this instance of.
-  ExecutionContextKind getKind() const { return k; }
-  virtual ~ExecutionContext() = default;
-
-  // Executes the operation on the provided inputs and populate the OutputList
-  // with the results. The input tensors must match the current context.
-  // The effect of "executing" an operation depends on the context: in an Eager
-  // context it will dispatch it to the runtime for execution, while in a
-  // tracing context it will add the operation to the current function.
-  virtual void ExecuteOperation(AbstractOp* op, int num_inputs,
-                                AbstractTensor* const* inputs, OutputList* o,
-                                TF_Status* s) = 0;
-
-  // Creates an empty AbstractOperation suitable to use with this context.
-  virtual AbstractOp* CreateOperation() = 0;
-
   // Add a function parameter and return the corresponding tensor.
-  // This is only valid with an ExecutionContext obtained from a TracingContext,
-  // it'll always error out with an eager context.
-  virtual AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) = 0;
+  virtual Status AddParameter(DataType dtype, TracingTensorHandle**) = 0;

   // Finalize this context and make a function out of it. The context is in a
   // invalid state after this call and must be destroyed.
-  // This is only valid with an ExecutionContext obtained from a TracingContext,
-  // it'll always error out with an eager context.
-  virtual AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) = 0;
-
-  // Registers a functions with this context, after this the function is
-  // available to be called/referenced by its name in this context.
-  virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
-
- private:
-  const ExecutionContextKind k;
+  virtual Status Finalize(OutputList* outputs, AbstractFunction**) = 0;
+
+  // For LLVM style RTTI.
+  static bool classof(const AbstractContext* ptr) {
+    return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
+  }
 };

-typedef ExecutionContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
+typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
 void SetDefaultTracingEngine(const char* name);
 void RegisterTracingEngineFactory(const ::tensorflow::string& name,
                                   FactoryFunction factory);
+}  // namespace tracing

-// Create utilities to wrap/unwrap: this convert from the C opaque types to the
-// C++ implementation, and back.
-#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS)                               \
-  static inline CPP_CLASS* const& unwrap(C_TYPEDEF* const& o) {              \
-    return reinterpret_cast<CPP_CLASS* const&>(o);                           \
-  }                                                                          \
-  static inline const CPP_CLASS* const& unwrap(const C_TYPEDEF* const& o) {  \
-    return reinterpret_cast<const CPP_CLASS* const&>(o);                     \
-  }                                                                          \
-  static inline C_TYPEDEF* const& wrap(CPP_CLASS* const& o) {                \
-    return reinterpret_cast<C_TYPEDEF* const&>(o);                           \
-  }                                                                          \
-  static inline const C_TYPEDEF* const& wrap(const CPP_CLASS* const& o) {    \
-    return reinterpret_cast<const C_TYPEDEF* const&>(o);                     \
-  }
-
-MAKE_WRAP_UNWRAP(TF_ExecutionContext, ExecutionContext)
-MAKE_WRAP_UNWRAP(TF_AbstractFunction, AbstractFunction)
-MAKE_WRAP_UNWRAP(TF_AbstractTensor, AbstractTensor)
-MAKE_WRAP_UNWRAP(TF_AbstractOp, AbstractOp)
-MAKE_WRAP_UNWRAP(TF_OutputList, OutputList)
-
-}  // namespace internal
+DEFINE_CONVERSION_FUNCTIONS(AbstractContext, TF_ExecutionContext)
+DEFINE_CONVERSION_FUNCTIONS(AbstractTensorHandle, TF_AbstractTensor)
+DEFINE_CONVERSION_FUNCTIONS(AbstractFunction, TF_AbstractFunction)
+DEFINE_CONVERSION_FUNCTIONS(AbstractOperation, TF_AbstractOp)
+DEFINE_CONVERSION_FUNCTIONS(OutputList, TF_OutputList)
 }  // namespace tensorflow

 #endif  // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
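Aside, not part of the change: the header above leans on LLVM-style RTTI (a kind tag set in the constructor, a static classof predicate, and tensorflow::dyn_cast). The following is a rough, self-contained sketch of that pattern; the class names and the hand-rolled dyn_cast_sketch helper are made up for illustration and are not the types introduced by this commit.

// Minimal sketch of LLVM-style RTTI: kind enum + classof + dyn_cast.
#include <iostream>

class AbstractBase {
 public:
  enum Kind { kGraph, kMlir, kEager };
  explicit AbstractBase(Kind kind) : kind_(kind) {}
  virtual ~AbstractBase() = default;
  Kind getKind() const { return kind_; }

 private:
  const Kind kind_;
};

// A "tracing" layer accepts only the graph and MLIR kinds.
class TracingBase : public AbstractBase {
 public:
  explicit TracingBase(Kind kind) : AbstractBase(kind) {}
  // dyn_cast/isa consult this predicate instead of C++ RTTI.
  static bool classof(const AbstractBase* ptr) {
    return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
  }
};

class MlirBase : public TracingBase {
 public:
  MlirBase() : TracingBase(kMlir) {}
  static bool classof(const AbstractBase* ptr) {
    return ptr->getKind() == kMlir;
  }
};

// Stand-in for tensorflow::dyn_cast / llvm::dyn_cast.
template <typename To, typename From>
To* dyn_cast_sketch(From* from) {
  return To::classof(from) ? static_cast<To*>(from) : nullptr;
}

int main() {
  MlirBase mlir_obj;
  AbstractBase* base = &mlir_obj;
  // Both casts succeed because kMlir satisfies both classof predicates.
  std::cout << (dyn_cast_sketch<TracingBase>(base) != nullptr) << "\n";  // 1
  std::cout << (dyn_cast_sketch<MlirBase>(base) != nullptr) << "\n";     // 1
  return 0;
}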
@@ -18,6 +18,7 @@ limitations under the License.
 #include <memory>

 #include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
 #include "tensorflow/c/eager/c_api_test_util.h"
 #include "tensorflow/c/tf_datatype.h"
 #include "tensorflow/c/tf_status.h"
@@ -29,15 +30,19 @@ using tensorflow::string;
 namespace tensorflow {
 namespace {

-class UnifiedCAPI : public ::testing::TestWithParam<const char*> {
+class UnifiedCAPI
+    : public ::testing::TestWithParam<std::tuple<const char*, bool>> {
  protected:
-  void SetUp() override { TF_SetTracingImplementation(GetParam()); }
+  void SetUp() override {
+    TF_SetTracingImplementation(std::get<0>(GetParam()));
+  }
 };

 TEST_P(UnifiedCAPI, TestBasicEager) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetTfrt(opts, std::get<1>(GetParam()));
   TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   TFE_DeleteContextOptions(opts);
@@ -45,7 +50,8 @@ TEST_P(UnifiedCAPI, TestBasicEager) {
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

   // Build an abstract input tensor.
-  TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
+  TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
   TF_AbstractTensor* at =
       TF_CreateAbstractTensorFromEagerTensor(t, status.get());
@@ -63,7 +69,7 @@ TEST_P(UnifiedCAPI, TestBasicEager) {
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

   // Execute.
-  TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get());
+  TF_ExecuteOperation(op, 2, inputs, o, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

   // Clean up operation and inputs.
@@ -109,9 +115,11 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {
   // Build inputs and outputs.
   TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
   TF_OutputList* add_outputs = TF_NewOutputList();
+  TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

   // Execute.
-  TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get());
+  TF_ExecuteOperation(add_op, 2, inputs, add_outputs, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

   // Clean up operation and inputs.
@@ -123,6 +131,7 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {

   // Build eager context.
   TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetTfrt(opts, std::get<1>(GetParam()));
   TF_ExecutionContext* eager_execution_ctx =
       TF_NewEagerExecutionContext(opts, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
@@ -137,16 +146,14 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {

   // Build an abstract input tensor.
   TFE_Context* eager_ctx =
-      TF_ExecutionContextGetTFEContext(eager_execution_ctx);
+      TF_ExecutionContextGetTFEContext(eager_execution_ctx, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
   TF_AbstractTensor* input_t =
       TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

-  TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-
-  TF_ExecuteOperation(fn_op, 1, &input_t, add_outputs, eager_execution_ctx,
-                      status.get());
+  TF_ExecuteOperation(fn_op, 1, &input_t, add_outputs, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

   ASSERT_EQ(1, TF_OutputListNumOutputs(add_outputs));
@@ -195,8 +202,10 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
   TF_AbstractTensor* inputs[2] = {arg0, arg1};
   TF_OutputList* add_outputs = TF_NewOutputList();
+  TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   // Trace the operation now (create a node in the graph).
-  TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
+  TF_ExecuteOperation(add_op, 2, inputs, add_outputs, s);
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
   TF_DeleteAbstractOp(add_op);
   // Extract the resulting tensor.
@@ -215,8 +224,10 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
   TF_AbstractTensor* inputs[2] = {arg1, arg1};
   TF_OutputList* add_outputs = TF_NewOutputList();
+  TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   // Trace the operation now (create a node in the graph).
-  TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
+  TF_ExecuteOperation(add_op, 2, inputs, add_outputs, s);
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
   TF_DeleteAbstractOp(add_op);
   // Extract the resulting tensor.
@@ -256,6 +267,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {

   // Build eager context.
   TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetTfrt(opts, std::get<1>(GetParam()));
   TF_ExecutionContext* eager_execution_ctx =
       TF_NewEagerExecutionContext(opts, s);
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
@@ -273,7 +285,8 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
   std::vector<TF_AbstractTensor*> func_args;
   {
     TFE_Context* eager_ctx =
-        TF_ExecutionContextGetTFEContext(eager_execution_ctx);
+        TF_ExecutionContextGetTFEContext(eager_execution_ctx, status.get());
+    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
     TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
     func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
     ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
@@ -286,7 +299,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
   TF_OutputListSetNumOutputs(func_outputs, 2, s);
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
   TF_ExecuteOperation(fn_op, func_args.size(), func_args.data(), func_outputs,
-                      eager_execution_ctx, s);
+                      s);
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
   TF_DeleteAbstractOp(fn_op);
   for (TF_AbstractTensor* t : func_args) TF_DeleteAbstractTensor(t);
@@ -314,20 +327,21 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
   TF_DeleteAbstractFunction(func);
 }

-TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
+TEST_P(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetTfrt(opts, std::get<1>(GetParam()));
   TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   TFE_DeleteContextOptions(opts);

-  TF_AbstractFunction* func = TF_FinalizeFunction(ctx, nullptr, status.get());
-  ASSERT_EQ(nullptr, func);
+  TF_AbstractFunction* f = TF_FinalizeFunction(ctx, nullptr, status.get());
+  ASSERT_EQ(nullptr, f);
   ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
 }

-TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
+TEST_P(UnifiedCAPI, TF_AbstractOpSetOpTypeAfterFinishingOpBuildingRaises) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
@@ -348,7 +362,7 @@ TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
   TF_DeleteExecutionContext(graph_ctx);
 }

-TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
+TEST_P(UnifiedCAPI, TF_AbstractOpSetOpNameAfterFinishingOpBuildingRaises) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
@@ -369,116 +383,44 @@ TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
   TF_DeleteExecutionContext(graph_ctx);
 }

-TEST_P(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
-  // Build an Eager context.
+TEST_P(UnifiedCAPI, TF_AbstractTensorGetEagerTensorOnGraphTensorRaises) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
-  TFE_ContextOptions* opts = TFE_NewContextOptions();
-  TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  TFE_DeleteContextOptions(opts);
-
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-
-  // Build an Eager operation.
-  auto* op = TF_NewAbstractOp(ctx);
-  TF_AbstractOpSetOpType(op, "Add", status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-
-  // Build an abstract input tensor.
-  TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
-  TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
-  TF_AbstractTensor* at =
-      TF_CreateAbstractTensorFromEagerTensor(t, status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-
-  // Build inputs and outputs.
-  TF_AbstractTensor* inputs[2] = {at, at};
-  TF_OutputList* o = TF_NewOutputList();
-  TF_OutputListSetNumOutputs(o, 1, status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-
-  // Build a Graph context.
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-
-  // Execute eager op using graph context.
-  TF_ExecuteOperation(op, 2, inputs, o, graph_ctx, status.get());
-  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
-
-  // Clean up operation and inputs.
-  TF_DeleteAbstractOp(op);
-  TF_DeleteAbstractTensor(at);
-
-  TF_DeleteOutputList(o);
-  TF_DeleteExecutionContext(ctx);
-  TF_DeleteExecutionContext(graph_ctx);
-}
-
-TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
-  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
-      TF_NewStatus(), TF_DeleteStatus);
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

   // Add a placeholder to the graph.
-  auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
-  TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-
-  // Build inputs and outputs.
-  TF_OutputList* placeholder_outputs = TF_NewOutputList();
-
-  // Execute.
-  TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
-                      graph_ctx, status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
-  TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
-
-  // Delete placeholder op.
-  TF_DeleteAbstractOp(placeholder_op);
-
-  // Build an abstract operation.
-  auto* add_op = TF_NewAbstractOp(graph_ctx);
-  TF_AbstractOpSetOpType(add_op, "Add", status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  TF_AbstractOpSetOpName(add_op, "my_add", status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-
-  // Build inputs and outputs.
-  TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
-  TF_OutputList* add_outputs = TF_NewOutputList();
-
-  // Build eager context.
-  TFE_ContextOptions* opts = TFE_NewContextOptions();
-  TF_ExecutionContext* eager_execution_ctx =
-      TF_NewEagerExecutionContext(opts, status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  TFE_DeleteContextOptions(opts);
-
-  // Execute.
-  TF_ExecuteOperation(add_op, 2, inputs, add_outputs, eager_execution_ctx,
-                      status.get());
+  auto placeholder_t =
+      TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
+  TF_AbstractTensorGetEagerTensor(placeholder_t, status.get());
   ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));

-  // Clean up operation and inputs.
   TF_DeleteAbstractTensor(placeholder_t);
-  TF_DeleteAbstractOp(add_op);
-  TF_DeleteOutputList(add_outputs);
-  TF_DeleteOutputList(placeholder_outputs);
   TF_DeleteExecutionContext(graph_ctx);
-  TF_DeleteExecutionContext(eager_execution_ctx);
 }

+TEST_P(UnifiedCAPI, TF_ExecutionContextGetTFEContextFromFunctionContextRaises) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  TF_ExecutionContextGetTFEContext(graph_ctx, status.get());
+  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
+
+  TF_DeleteExecutionContext(graph_ctx);
+}
+
+#ifdef PLATFORM_GOOGLE
 INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
-                         ::testing::Values("graphdef", "mlir"));
+                         ::testing::Combine(::testing::Values("graphdef",
+                                                              "mlir"),
+                                            ::testing::Values(true, false)));
+#else
+INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
+                         ::testing::Combine(::testing::Values("graphdef",
+                                                              "mlir"),
+                                            ::testing::Values(false)));
+#endif

 }  // namespace
 }  // namespace tensorflow
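Aside, not part of the change: the suite above is now value-parameterized over a (tracing implementation, use-TFRT) tuple, and ::testing::Combine instantiates the cross product of the two value lists. A minimal, self-contained sketch of that googletest pattern follows; the suite and test names are invented for the example.

// Illustrative tuple-parameterized suite: Combine yields four cases here,
// ("graphdef", false), ("graphdef", true), ("mlir", false), ("mlir", true).
#include <string>
#include <tuple>

#include <gtest/gtest.h>

class TracingParamSketch
    : public ::testing::TestWithParam<std::tuple<const char*, bool>> {};

TEST_P(TracingParamSketch, ReadsBothFields) {
  const std::string tracing_impl = std::get<0>(GetParam());
  const bool use_tfrt = std::get<1>(GetParam());
  // A real test would configure the runtime from these values; this sketch
  // only checks that the parameters are well-formed.
  EXPECT_FALSE(tracing_impl.empty());
  (void)use_tfrt;
}

INSTANTIATE_TEST_SUITE_P(
    Sketch, TracingParamSketch,
    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                       ::testing::Values(false, true)));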
@@ -23,8 +23,12 @@ tf_cuda_library(
     copts = tf_copts() + tfe_xla_copts(),
     deps = [
         "//tensorflow/c:c_api",
+        "//tensorflow/c:tensor_interface",
         "//tensorflow/c:tf_status_helper",
         "//tensorflow/c:tf_status_internal",
+        "//tensorflow/c/eager:abstract_context",
+        "//tensorflow/c/eager:abstract_operation",
+        "//tensorflow/c/eager:abstract_tensor_handle",
         "//tensorflow/c/eager:c_api",
         "//tensorflow/c/eager:c_api_internal",
         "//tensorflow/c/eager:c_api_unified_internal",
@@ -35,6 +39,7 @@ tf_cuda_library(
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
         "//tensorflow/core:framework",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/lib/llvm_rtti",
         "//tensorflow/core/platform:errors",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
|
@ -26,14 +26,19 @@ limitations under the License.
|
|||||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||||
#include "mlir/IR/Module.h" // from @llvm-project
|
#include "mlir/IR/Module.h" // from @llvm-project
|
||||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/abstract_context.h"
|
||||||
|
#include "tensorflow/c/eager/abstract_operation.h"
|
||||||
|
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||||
#include "tensorflow/c/eager/c_api.h"
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||||
|
#include "tensorflow/c/tensor_interface.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
#include "tensorflow/c/tf_status_internal.h"
|
#include "tensorflow/c/tf_status_internal.h"
|
||||||
@ -47,16 +52,21 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace TF {
|
namespace TF {
|
||||||
using tensorflow::internal::AbstractFunction;
|
using tensorflow::AbstractFunction;
|
||||||
using tensorflow::internal::AbstractOp;
|
using tensorflow::AbstractOperation;
|
||||||
using tensorflow::internal::AbstractTensor;
|
using tensorflow::AbstractTensorHandle;
|
||||||
using tensorflow::internal::dyncast;
|
using tensorflow::AbstractTensorInterface;
|
||||||
using tensorflow::internal::ExecutionContext;
|
using tensorflow::dyn_cast;
|
||||||
using tensorflow::internal::OutputList;
|
using tensorflow::OutputList;
|
||||||
|
using tensorflow::string;
|
||||||
|
using tensorflow::tracing::TracingContext;
|
||||||
|
using tensorflow::tracing::TracingOperation;
|
||||||
|
using tensorflow::tracing::TracingTensorHandle;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
@ -78,43 +88,104 @@ Status ConvertDataTypeToTensor(tensorflow::DataType dtype, Builder builder,
|
|||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
class MlirTensor : public AbstractTensor {
|
class MlirTensor : public TracingTensorHandle {
|
||||||
public:
|
public:
|
||||||
explicit MlirTensor(Value value) : AbstractTensor(kKind), value_(value) {}
|
explicit MlirTensor(Value value)
|
||||||
|
: TracingTensorHandle(kMlir), value_(value) {}
|
||||||
|
|
||||||
|
void Release() override { delete this; }
|
||||||
|
|
||||||
Value getValue() { return value_; }
|
Value getValue() { return value_; }
|
||||||
|
|
||||||
static constexpr AbstractTensorKind kKind = kMlirTensor;
|
// For LLVM style RTTI.
|
||||||
|
static bool classof(const AbstractTensorHandle* ptr) {
|
||||||
|
return ptr->getKind() == kMlir;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Value value_;
|
Value value_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class MlirAbstractOp : public AbstractOp {
|
class MlirFunctionContext;
|
||||||
|
|
||||||
|
class MlirAbstractOp : public TracingOperation {
|
||||||
public:
|
public:
|
||||||
explicit MlirAbstractOp(MLIRContext* context)
|
explicit MlirAbstractOp(MLIRContext* context,
|
||||||
: AbstractOp(kKind), context_(context) {}
|
MlirFunctionContext* function_context)
|
||||||
|
: TracingOperation(kMlir),
|
||||||
|
context_(context),
|
||||||
|
function_context_(function_context) {}
|
||||||
|
|
||||||
void SetOpType(const char* op_type, TF_Status* s) override;
|
void Release() override { delete this; }
|
||||||
|
|
||||||
void SetAttrType(const char* attr_name, TF_DataType dtype,
|
Status Reset(const char* op, const char* raw_device_name) override;
|
||||||
TF_Status* s) override;
|
|
||||||
|
|
||||||
void SetOpName(const char* const op_name, TF_Status* s) override;
|
const string& Name() const override;
|
||||||
|
|
||||||
|
const string& DeviceName() const override;
|
||||||
|
|
||||||
|
Status SetDeviceName(const char* name) override;
|
||||||
|
|
||||||
|
Status AddInput(AbstractTensorHandle* input) override;
|
||||||
|
Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) override;
|
||||||
|
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||||
|
int* num_retvals) override;
|
||||||
|
|
||||||
|
Status SetAttrString(const char* attr_name, const char* data,
|
||||||
|
size_t length) override;
|
||||||
|
Status SetAttrInt(const char* attr_name, int64_t value) override;
|
||||||
|
Status SetAttrFloat(const char* attr_name, float value) override;
|
||||||
|
Status SetAttrBool(const char* attr_name, bool value) override;
|
||||||
|
Status SetAttrType(const char* attr_name,
|
||||||
|
tensorflow::DataType dtype) override;
|
||||||
|
Status SetAttrShape(const char* attr_name, const int64_t* dims,
|
||||||
|
const int num_dims) override;
|
||||||
|
Status SetAttrFunction(const char* attr_name,
|
||||||
|
const AbstractOperation* value) override;
|
||||||
|
Status SetAttrFunctionName(const char* attr_name, const char* value,
|
||||||
|
size_t length) override;
|
||||||
|
Status SetAttrTensor(const char* attr_name,
|
||||||
|
AbstractTensorInterface* tensor) override;
|
||||||
|
Status SetAttrStringList(const char* attr_name, const void* const* values,
|
||||||
|
const size_t* lengths, int num_values) override;
|
||||||
|
Status SetAttrFloatList(const char* attr_name, const float* values,
|
||||||
|
int num_values) override;
|
||||||
|
Status SetAttrIntList(const char* attr_name, const int64_t* values,
|
||||||
|
int num_values) override;
|
||||||
|
Status SetAttrTypeList(const char* attr_name,
|
||||||
|
const tensorflow::DataType* values,
|
||||||
|
int num_values) override;
|
||||||
|
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
|
||||||
|
int num_values) override;
|
||||||
|
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
|
||||||
|
const int* num_dims, int num_values) override;
|
||||||
|
Status SetAttrFunctionList(
|
||||||
|
const char* attr_name,
|
||||||
|
absl::Span<const AbstractOperation*> values) override;
|
||||||
|
|
||||||
|
Status SetOpName(const char* const op_name) override;
|
||||||
|
|
||||||
MLIRContext* GetContext() { return context_; }
|
MLIRContext* GetContext() { return context_; }
|
||||||
|
|
||||||
Type AddRef(Type type, TF_Status* s);
|
Status AddRef(Type type, Type* output_type);
|
||||||
|
|
||||||
OperationState* Create(ArrayRef<Value> operands, TF_Status* s);
|
Status Create(ArrayRef<Value> operands, OperationState**);
|
||||||
|
|
||||||
static constexpr AbstractOpKind kKind = kMlirOp;
|
// For LLVM style RTTI.
|
||||||
|
static bool classof(const AbstractOperation* ptr) {
|
||||||
|
return ptr->getKind() == kMlir;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
MLIRContext* context_;
|
MLIRContext* context_;
|
||||||
|
MlirFunctionContext* function_context_;
|
||||||
|
SmallVector<Value, 8> operands_;
|
||||||
llvm::StringMap<Attribute> attrs_;
|
llvm::StringMap<Attribute> attrs_;
|
||||||
std::unique_ptr<OperationState> state_;
|
std::unique_ptr<OperationState> state_;
|
||||||
const char* op_name_ = nullptr;
|
const char* op_name_ = nullptr;
|
||||||
|
string tf_op_type_;
|
||||||
|
// TODO(srbs): Use this.
|
||||||
|
string device_name_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// MlirFunction is a thin wrapper over a FuncOp.
|
// MlirFunction is a thin wrapper over a FuncOp.
|
||||||
@ -122,14 +193,17 @@ class MlirFunction : public AbstractFunction {
|
|||||||
public:
|
public:
|
||||||
explicit MlirFunction(std::unique_ptr<MLIRContext> context,
|
explicit MlirFunction(std::unique_ptr<MLIRContext> context,
|
||||||
OwningModuleRef module, FuncOp func)
|
OwningModuleRef module, FuncOp func)
|
||||||
: AbstractFunction(kKind),
|
: AbstractFunction(kMlir),
|
||||||
context_(std::move(context)),
|
context_(std::move(context)),
|
||||||
module_(std::move(module)),
|
module_(std::move(module)),
|
||||||
func_(func) {}
|
func_(func) {}
|
||||||
|
|
||||||
TF_Function* GetTfFunction(TF_Status* s) override;
|
Status GetFunctionDef(tensorflow::FunctionDef** f) override;
|
||||||
|
|
||||||
static constexpr AbstractFunctionKind kKind = kGraphFunc;
|
// For LLVM style RTTI.
|
||||||
|
static bool classof(const AbstractFunction* ptr) {
|
||||||
|
return ptr->getKind() == kMlir;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<MLIRContext> context_;
|
std::unique_ptr<MLIRContext> context_;
|
||||||
@ -137,10 +211,10 @@ class MlirFunction : public AbstractFunction {
|
|||||||
FuncOp func_;
|
FuncOp func_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class MlirFunctionContext : public ExecutionContext {
|
class MlirFunctionContext : public TracingContext {
|
||||||
public:
|
public:
|
||||||
explicit MlirFunctionContext(const char* name)
|
explicit MlirFunctionContext(const char* name)
|
||||||
: ExecutionContext(kKind),
|
: TracingContext(kMlir),
|
||||||
context_(std::make_unique<MLIRContext>()),
|
context_(std::make_unique<MLIRContext>()),
|
||||||
builder_(context_.get()) {
|
builder_(context_.get()) {
|
||||||
// TODO(aminim) figure out the location story here
|
// TODO(aminim) figure out the location story here
|
||||||
@ -151,24 +225,27 @@ class MlirFunctionContext : public ExecutionContext {
|
|||||||
builder_ = OpBuilder::atBlockBegin(func_.addEntryBlock());
|
builder_ = OpBuilder::atBlockBegin(func_.addEntryBlock());
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractOp* CreateOperation() override {
|
void Release() override { delete this; }
|
||||||
return new MlirAbstractOp(context_.get());
|
|
||||||
|
AbstractOperation* CreateOperation() override {
|
||||||
|
return new MlirAbstractOp(context_.get(), this);
|
||||||
}
|
}
|
||||||
|
Status AddParameter(tensorflow::DataType dtype,
|
||||||
|
TracingTensorHandle** handle) override;
|
||||||
|
|
||||||
void ExecuteOperation(AbstractOp* abstract_op, int num_inputs,
|
Status Finalize(OutputList* outputs, AbstractFunction** f) override;
|
||||||
AbstractTensor* const* inputs, OutputList* o,
|
|
||||||
TF_Status* s) override;
|
|
||||||
|
|
||||||
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override;
|
Status RegisterFunction(AbstractFunction* func) override {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override;
|
|
||||||
|
|
||||||
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
|
|
||||||
s->status = tensorflow::errors::Unimplemented(
|
|
||||||
"Registering graph functions has not been implemented yet.");
|
"Registering graph functions has not been implemented yet.");
|
||||||
}
|
}
|
||||||
|
|
||||||
static constexpr ExecutionContextKind kKind = kMlirContext;
|
Status RemoveFunction(const string& func) override {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"MlirFunctionContext::RemoveFunction has not been implemented yet.");
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* CreateOperationFromState(const OperationState& state);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<MLIRContext> context_;
|
std::unique_ptr<MLIRContext> context_;
|
||||||
@ -177,91 +254,88 @@ class MlirFunctionContext : public ExecutionContext {
|
|||||||
OwningModuleRef module_;
|
OwningModuleRef module_;
|
||||||
};
|
};
|
||||||
|
|
||||||
void MlirAbstractOp::SetOpType(const char* op_type, TF_Status* s) {
|
Status MlirAbstractOp::Reset(const char* op, const char* device_name) {
|
||||||
if (state_) {
|
if (state_) {
|
||||||
s->status = tensorflow::errors::FailedPrecondition(
|
return tensorflow::errors::FailedPrecondition(
|
||||||
"SetOpType called on already built op.");
|
"Reset called on already built op.");
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
tf_op_type_ = op;
|
||||||
std::string name = "tf.";
|
std::string name = "tf.";
|
||||||
name += op_type;
|
name += op;
|
||||||
// TODO(aminim) figure out the location story here
|
// TODO(aminim) figure out the location story here
|
||||||
state_ = std::make_unique<OperationState>(UnknownLoc::get(context_), name);
|
state_ = std::make_unique<OperationState>(UnknownLoc::get(context_), name);
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
void MlirAbstractOp::SetAttrType(const char* attr_name, TF_DataType dtype,
|
Status MlirAbstractOp::SetAttrType(const char* attr_name,
|
||||||
TF_Status* s) {
|
tensorflow::DataType dtype) {
|
||||||
if (!state_) {
|
if (!state_) {
|
||||||
s->status = tensorflow::errors::FailedPrecondition(
|
return Status(tensorflow::error::Code::FAILED_PRECONDITION,
|
||||||
"op_type must be specified before specifying attrs.");
|
"op_type must be specified before specifying attrs.");
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
Type mlir_type;
|
Type mlir_type;
|
||||||
Builder builder(context_);
|
Builder builder(context_);
|
||||||
s->status = ConvertDataTypeToTensor(static_cast<tensorflow::DataType>(dtype),
|
TF_RETURN_IF_ERROR(ConvertDataTypeToTensor(dtype, builder, &mlir_type));
|
||||||
builder, &mlir_type);
|
|
||||||
if (!s->status.ok()) return;
|
|
||||||
attrs_[attr_name] = TypeAttr::get(mlir_type);
|
attrs_[attr_name] = TypeAttr::get(mlir_type);
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
void MlirAbstractOp::SetOpName(const char* const op_name, TF_Status* s) {
|
Status MlirAbstractOp::SetOpName(const char* const op_name) {
|
||||||
// TODO(aminim): should we use a location?
|
// TODO(aminim): should we use a location?
|
||||||
if (op_name_) {
|
if (op_name_) {
|
||||||
s->status = tensorflow::errors::FailedPrecondition(
|
return tensorflow::errors::FailedPrecondition(
|
||||||
"SetOpName called on already built op.");
|
"SetOpName called on already built op.");
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
op_name_ = op_name;
|
op_name_ = op_name;
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Type MlirAbstractOp::AddRef(Type type, TF_Status* s) {
|
Status MlirAbstractOp::AddRef(Type type, Type* output_type) {
|
||||||
Type elt_type = getElementTypeOrSelf(type);
|
Type elt_type = getElementTypeOrSelf(type);
|
||||||
if (elt_type.isa<mlir::TF::TensorFlowRefType>()) {
|
if (elt_type.isa<mlir::TF::TensorFlowRefType>()) {
|
||||||
s->status = tensorflow::errors::InvalidArgument(
|
return tensorflow::errors::InvalidArgument(
|
||||||
"Requested reference to a reference type");
|
"Requested reference to a reference type");
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
elt_type = TensorFlowRefType::get(elt_type);
|
elt_type = TensorFlowRefType::get(elt_type);
|
||||||
if (RankedTensorType tensor_type = type.dyn_cast<RankedTensorType>()) {
|
if (RankedTensorType tensor_type = type.dyn_cast<RankedTensorType>()) {
|
||||||
return RankedTensorType::get(tensor_type.getShape(), elt_type);
|
*output_type = RankedTensorType::get(tensor_type.getShape(), elt_type);
|
||||||
}
|
}
|
||||||
return UnrankedTensorType::get(elt_type);
|
*output_type = UnrankedTensorType::get(elt_type);
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
OperationState* MlirAbstractOp::Create(ArrayRef<Value> operands, TF_Status* s) {
|
Status MlirAbstractOp::Create(ArrayRef<Value> operands,
|
||||||
|
OperationState** state) {
|
||||||
state_->operands = llvm::to_vector<4>(operands);
|
state_->operands = llvm::to_vector<4>(operands);
|
||||||
const tensorflow::OpDef* op_def;
|
const tensorflow::OpDef* op_def;
|
||||||
auto node_name = state_->name.getStringRef().drop_front(
|
auto node_name = state_->name.getStringRef().drop_front(
|
||||||
TensorFlowDialect::getDialectNamespace().size() + 1);
|
TensorFlowDialect::getDialectNamespace().size() + 1);
|
||||||
s->status =
|
TF_RETURN_IF_ERROR(
|
||||||
tensorflow::OpRegistry::Global()->LookUpOpDef(node_name.str(), &op_def);
|
tensorflow::OpRegistry::Global()->LookUpOpDef(node_name.str(), &op_def));
|
||||||
if (!s->status.ok()) return nullptr;
|
|
||||||
Builder builder(context_);
|
Builder builder(context_);
|
||||||
// Process operands according to the op_def and infer derived attributes.
|
// Process operands according to the op_def and infer derived attributes.
|
||||||
int current_operand = 0;
|
int current_operand = 0;
|
||||||
for (const tensorflow::OpDef::ArgDef& input_arg : op_def->input_arg()) {
|
for (const tensorflow::OpDef::ArgDef& input_arg : op_def->input_arg()) {
|
||||||
if (!input_arg.number_attr().empty()) {
|
if (!input_arg.number_attr().empty()) {
|
||||||
// TODO(b/156122856): we don't support variadic operands.
|
// TODO(b/156122856): we don't support variadic operands.
|
||||||
s->status = tensorflow::errors::Unimplemented(
|
return tensorflow::errors::Unimplemented(
|
||||||
"Unsupported 'number_attr' for '", input_arg.number_attr(), "'");
|
"Unsupported 'number_attr' for '", input_arg.number_attr(), "'");
|
||||||
return nullptr;
|
|
||||||
} else if (!input_arg.type_list_attr().empty()) {
|
} else if (!input_arg.type_list_attr().empty()) {
|
||||||
s->status = tensorflow::errors::InvalidArgument(
|
return tensorflow::errors::InvalidArgument(
|
||||||
"Unsupported 'type_list_attr' for '", input_arg.number_attr(), "'");
|
"Unsupported 'type_list_attr' for '", input_arg.number_attr(), "'");
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
if (current_operand >= operands.size()) {
|
if (current_operand >= operands.size()) {
|
||||||
s->status = tensorflow::errors::InvalidArgument("Missing operand for '",
|
return tensorflow::errors::InvalidArgument("Missing operand for '",
|
||||||
input_arg.name(), "'");
|
input_arg.name(), "'");
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
Type expected_type;
|
Type expected_type;
|
||||||
if (input_arg.type() != tensorflow::DT_INVALID) {
|
if (input_arg.type() != tensorflow::DT_INVALID) {
|
||||||
s->status =
|
TF_RETURN_IF_ERROR(
|
||||||
ConvertDataTypeToTensor(input_arg.type(), builder, &expected_type);
|
ConvertDataTypeToTensor(input_arg.type(), builder, &expected_type));
|
||||||
if (!s->status.ok()) return nullptr;
|
Type output_type;
|
||||||
if (input_arg.is_ref()) expected_type = AddRef(expected_type, s);
|
if (input_arg.is_ref())
|
||||||
if (!s->status.ok()) return nullptr;
|
TF_RETURN_IF_ERROR(AddRef(expected_type, &output_type));
|
||||||
|
expected_type = output_type;
|
||||||
} else {
|
} else {
|
||||||
expected_type = operands[current_operand].getType();
|
expected_type = operands[current_operand].getType();
|
||||||
}
|
}
|
||||||
@ -277,17 +351,15 @@ OperationState* MlirAbstractOp::Create(ArrayRef<Value> operands, TF_Status* s) {
|
|||||||
// Same type repeated "repeats" times.
|
// Same type repeated "repeats" times.
|
||||||
Attribute repeats_attr = attrs_[output_arg.number_attr()];
|
Attribute repeats_attr = attrs_[output_arg.number_attr()];
|
||||||
if (!repeats_attr) {
|
if (!repeats_attr) {
|
||||||
s->status = tensorflow::errors::InvalidArgument(
|
return tensorflow::errors::InvalidArgument(
|
||||||
"Missing attribute '", output_arg.number_attr(),
|
"Missing attribute '", output_arg.number_attr(),
|
||||||
"' required for output list '", output_arg.name(), "'");
|
"' required for output list '", output_arg.name(), "'");
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
if (!repeats_attr.isa<IntegerAttr>()) {
|
if (!repeats_attr.isa<IntegerAttr>()) {
|
||||||
s->status = tensorflow::errors::InvalidArgument(
|
return tensorflow::errors::InvalidArgument(
|
||||||
"Attribute '", output_arg.number_attr(),
|
"Attribute '", output_arg.number_attr(),
|
||||||
"' required for output list '", output_arg.name(),
|
"' required for output list '", output_arg.name(),
|
||||||
"' isn't an integer");
|
"' isn't an integer");
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
int64_t repeats = repeats_attr.cast<IntegerAttr>().getInt();
|
int64_t repeats = repeats_attr.cast<IntegerAttr>().getInt();
|
||||||
|
|
||||||
@ -295,102 +367,186 @@ OperationState* MlirAbstractOp::Create(ArrayRef<Value> operands, TF_Status* s) {
|
|||||||
// Same type repeated "repeats" times.
|
// Same type repeated "repeats" times.
|
||||||
Attribute attr = attrs_[output_arg.type_attr()];
|
Attribute attr = attrs_[output_arg.type_attr()];
|
||||||
if (!attr) {
|
if (!attr) {
|
||||||
s->status = tensorflow::errors::InvalidArgument(
|
return tensorflow::errors::InvalidArgument(
|
||||||
"Missing attribute '", output_arg.type_attr(),
|
"Missing attribute '", output_arg.type_attr(),
|
||||||
"' required for output '", output_arg.name(), "'");
|
"' required for output '", output_arg.name(), "'");
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
|
TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
|
||||||
if (!type_attr) {
|
if (!type_attr) {
|
||||||
s->status = tensorflow::errors::InvalidArgument(
|
return tensorflow::errors::InvalidArgument(
|
||||||
"Attribute '", output_arg.type_attr(), "' required for output '",
|
"Attribute '", output_arg.type_attr(), "' required for output '",
|
||||||
output_arg.name(), "' isn't a type attribute");
|
output_arg.name(), "' isn't a type attribute");
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
for (int i = 0; i < repeats; ++i)
|
for (int i = 0; i < repeats; ++i)
|
||||||
state_->types.push_back(type_attr.getType());
|
state_->types.push_back(type_attr.getType());
|
||||||
} else if (output_arg.type() != tensorflow::DT_INVALID) {
|
} else if (output_arg.type() != tensorflow::DT_INVALID) {
|
||||||
for (int i = 0; i < repeats; ++i) {
|
for (int i = 0; i < repeats; ++i) {
|
||||||
Type type;
|
Type type;
|
||||||
s->status =
|
TF_RETURN_IF_ERROR(
|
||||||
ConvertDataTypeToTensor(output_arg.type(), builder, &type);
|
ConvertDataTypeToTensor(output_arg.type(), builder, &type));
|
||||||
if (!s->status.ok()) return nullptr;
|
|
||||||
state_->types.push_back(type);
|
state_->types.push_back(type);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
s->status = tensorflow::errors::InvalidArgument(
|
return tensorflow::errors::InvalidArgument(
|
||||||
"Missing type or type_attr field in ",
|
"Missing type or type_attr field in ",
|
||||||
output_arg.ShortDebugString());
|
output_arg.ShortDebugString());
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
} else if (!output_arg.type_attr().empty()) {
|
} else if (!output_arg.type_attr().empty()) {
|
||||||
Attribute attr = attrs_[output_arg.type_attr()];
|
Attribute attr = attrs_[output_arg.type_attr()];
|
||||||
if (!attr) {
|
if (!attr) {
|
||||||
s->status = tensorflow::errors::InvalidArgument(
|
return tensorflow::errors::InvalidArgument(
|
||||||
"Missing attribute '", output_arg.type_attr(),
|
"Missing attribute '", output_arg.type_attr(),
|
||||||
"' required for output '", output_arg.name(), "'");
|
"' required for output '", output_arg.name(), "'");
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
|
TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
|
||||||
if (!type_attr) {
|
if (!type_attr) {
|
||||||
s->status = tensorflow::errors::InvalidArgument(
|
return tensorflow::errors::InvalidArgument(
|
||||||
"Attribute '", output_arg.type_attr(), "' required for output '",
|
"Attribute '", output_arg.type_attr(), "' required for output '",
|
||||||
output_arg.name(), "' isn't a type attribute");
|
output_arg.name(), "' isn't a type attribute");
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
state_->types.push_back(type_attr.getValue());
|
state_->types.push_back(type_attr.getValue());
|
||||||
} else if (!output_arg.type_list_attr().empty()) {
|
} else if (!output_arg.type_list_attr().empty()) {
|
||||||
// This is pointing to an attribute which is an array of types.
|
// This is pointing to an attribute which is an array of types.
|
||||||
Attribute attr = attrs_[output_arg.type_list_attr()];
|
Attribute attr = attrs_[output_arg.type_list_attr()];
|
||||||
if (!attr) {
|
if (!attr) {
|
||||||
s->status = tensorflow::errors::InvalidArgument(
|
return tensorflow::errors::InvalidArgument(
|
||||||
"Missing attribute '", output_arg.type_list_attr(),
|
"Missing attribute '", output_arg.type_list_attr(),
|
||||||
"' required for output '", output_arg.name(), "'");
|
"' required for output '", output_arg.name(), "'");
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
ArrayAttr array_attr = attr.dyn_cast<ArrayAttr>();
|
ArrayAttr array_attr = attr.dyn_cast<ArrayAttr>();
|
||||||
if (!array_attr) {
|
if (!array_attr) {
|
||||||
s->status = tensorflow::errors::InvalidArgument(
|
return tensorflow::errors::InvalidArgument(
|
||||||
"Attribute '", output_arg.type_list_attr(),
|
"Attribute '", output_arg.type_list_attr(),
|
||||||
"' required for output '", output_arg.name(),
|
"' required for output '", output_arg.name(),
|
||||||
"' isn't an array attribute");
|
"' isn't an array attribute");
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
for (Attribute attr : array_attr) {
|
for (Attribute attr : array_attr) {
|
||||||
TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
|
TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
|
||||||
if (!type_attr) {
|
if (!type_attr) {
|
||||||
s->status = tensorflow::errors::InvalidArgument(
|
return tensorflow::errors::InvalidArgument(
|
||||||
"Array Attribute '", output_arg.type_list_attr(),
|
"Array Attribute '", output_arg.type_list_attr(),
|
||||||
"' required for output '", output_arg.name(),
|
"' required for output '", output_arg.name(),
|
||||||
"' has a non-Type element");
|
"' has a non-Type element");
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
state_->types.push_back(type_attr.getValue());
|
state_->types.push_back(type_attr.getValue());
|
||||||
}
|
}
|
||||||
} else if (output_arg.type() != tensorflow::DT_INVALID) {
|
} else if (output_arg.type() != tensorflow::DT_INVALID) {
|
||||||
Type type;
|
Type type;
|
||||||
Builder builder(context_);
|
Builder builder(context_);
|
||||||
s->status = ConvertDataTypeToTensor(output_arg.type(), builder, &type);
|
TF_RETURN_IF_ERROR(
|
||||||
if (!s->status.ok()) return nullptr;
|
ConvertDataTypeToTensor(output_arg.type(), builder, &type));
|
||||||
state_->types.push_back(type);
|
state_->types.push_back(type);
|
||||||
} else {
|
} else {
|
||||||
s->status = tensorflow::errors::InvalidArgument(
|
return tensorflow::errors::InvalidArgument("No type fields in ",
|
||||||
"No type fields in ", output_arg.ShortDebugString());
|
output_arg.ShortDebugString());
|
||||||
if (!s->status.ok()) return nullptr;
|
|
||||||
}
|
}
|
||||||
if (output_arg.is_ref()) {
|
    if (output_arg.is_ref()) {
      // For all types that were added by this function call, make them refs.
      for (Type& type : llvm::make_range(&state_->types[original_size],
                                         state_->types.end())) {
-        type = AddRef(type, s);
-        if (!s->status.ok()) return nullptr;
+        Type output_type;
+        TF_RETURN_IF_ERROR(AddRef(type, &output_type));
+        type = output_type;
      }
    }
  }
-  return state_.get();
+  *state = state_.get();
+  return Status::OK();
 }
 
-TF_Function* MlirFunction::GetTfFunction(TF_Status* s) {
+const string& MlirAbstractOp::Name() const { return tf_op_type_; }
+
+const string& MlirAbstractOp::DeviceName() const { return device_name_; }
+
+Status MlirAbstractOp::SetDeviceName(const char* name) {
+  device_name_ = name;
+  return Status::OK();
+}
+
+Status MlirAbstractOp::AddInputList(absl::Span<AbstractTensorHandle*> inputs) {
+  return tensorflow::errors::Unimplemented(
+      "AddInputList has not been implemented yet.");
+}
+
+Status MlirAbstractOp::SetAttrString(const char* attr_name, const char* data,
+                                     size_t length) {
+  return tensorflow::errors::Unimplemented(
+      "SetAttrString has not been implemented yet.");
+}
+Status MlirAbstractOp::SetAttrInt(const char* attr_name, int64_t value) {
+  return tensorflow::errors::Unimplemented(
+      "SetAttrInt has not been implemented yet.");
+}
+Status MlirAbstractOp::SetAttrFloat(const char* attr_name, float value) {
+  return tensorflow::errors::Unimplemented(
+      "SetAttrFloat has not been implemented yet.");
+}
+Status MlirAbstractOp::SetAttrBool(const char* attr_name, bool value) {
+  return tensorflow::errors::Unimplemented(
+      "SetAttrBool has not been implemented yet.");
+}
+Status MlirAbstractOp::SetAttrShape(const char* attr_name, const int64_t* dims,
+                                    const int num_dims) {
+  return tensorflow::errors::Unimplemented(
+      "SetAttrShape has not been implemented yet.");
+}
+Status MlirAbstractOp::SetAttrFunction(const char* attr_name,
+                                       const AbstractOperation* value) {
+  return tensorflow::errors::Unimplemented(
+      "SetAttrFunction has not been implemented yet.");
+}
+Status MlirAbstractOp::SetAttrFunctionName(const char* attr_name,
+                                           const char* value, size_t length) {
+  return tensorflow::errors::Unimplemented(
+      "SetAttrFunctionName has not been implemented yet.");
+}
+Status MlirAbstractOp::SetAttrTensor(const char* attr_name,
+                                     AbstractTensorInterface* tensor) {
+  return tensorflow::errors::Unimplemented(
+      "SetAttrTensor has not been implemented yet.");
+}
+Status MlirAbstractOp::SetAttrStringList(const char* attr_name,
+                                         const void* const* values,
+                                         const size_t* lengths,
+                                         int num_values) {
+  return tensorflow::errors::Unimplemented(
+      "SetAttrStringList has not been implemented yet.");
+}
+Status MlirAbstractOp::SetAttrFloatList(const char* attr_name,
+                                        const float* values, int num_values) {
+  return tensorflow::errors::Unimplemented(
+      "SetAttrFloatList has not been implemented yet.");
+}
+Status MlirAbstractOp::SetAttrIntList(const char* attr_name,
+                                      const int64_t* values, int num_values) {
+  return tensorflow::errors::Unimplemented(
+      "SetAttrIntList has not been implemented yet.");
+}
+Status MlirAbstractOp::SetAttrTypeList(const char* attr_name,
+                                       const tensorflow::DataType* values,
+                                       int num_values) {
+  return tensorflow::errors::Unimplemented(
+      "SetAttrTypeList has not been implemented yet.");
+}
+Status MlirAbstractOp::SetAttrBoolList(const char* attr_name,
+                                       const unsigned char* values,
+                                       int num_values) {
+  return tensorflow::errors::Unimplemented(
+      "SetAttrBoolList has not been implemented yet.");
+}
+Status MlirAbstractOp::SetAttrShapeList(const char* attr_name,
+                                        const int64_t** dims,
+                                        const int* num_dims, int num_values) {
+  return tensorflow::errors::Unimplemented(
+      "SetAttrShapeList has not been implemented yet.");
+}
+Status MlirAbstractOp::SetAttrFunctionList(
+    const char* attr_name, absl::Span<const AbstractOperation*> values) {
+  return tensorflow::errors::Unimplemented(
+      "SetAttrFunctionList has not been implemented yet.");
+}
+
+Status MlirFunction::GetFunctionDef(tensorflow::FunctionDef** f) {
   PassManager pm(func_.getContext());
   pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
   pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
@@ -400,75 +556,59 @@ TF_Function* MlirFunction::GetTfFunction(TF_Status* s) {
   StatusScopedDiagnosticHandler diag_handler(func_.getContext());
   LogicalResult result = pm.run(func_.getParentOfType<ModuleOp>());
   (void)result;
-  s->status = diag_handler.ConsumeStatus();
-  if (!s->status.ok()) return nullptr;
+  TF_RETURN_IF_ERROR(diag_handler.ConsumeStatus());
 
   tensorflow::GraphExportConfig configs;
-  std::unique_ptr<TF_Function> tf_function(new TF_Function);
-  s->status = ConvertMlirFunctionToFunctionLibraryDef(func_, configs,
-                                                      &tf_function->fdef);
-  return tf_function.release();
+  *f = new tensorflow::FunctionDef();
+  return ConvertMlirFunctionToFunctionLibraryDef(func_, configs, *f);
 }
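Usage sketch (illustrative only, not part of this change): with the new signature the caller receives a heap-allocated FunctionDef through an out-parameter and a Status instead of a TF_Status/TF_Function pair. The name `mlir_function` is assumed here and is not defined in this diff.

  // Hypothetical caller of the new GetFunctionDef API.
  tensorflow::FunctionDef* fdef = nullptr;
  tensorflow::Status s = mlir_function->GetFunctionDef(&fdef);
  if (!s.ok()) {
    // Handle the error; no FunctionDef was produced.
  }
  // The caller is assumed to own `fdef` and to delete it when done.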
 
-void MlirFunctionContext::ExecuteOperation(AbstractOp* abstract_op,
-                                           int num_inputs,
-                                           AbstractTensor* const* inputs,
-                                           OutputList* o, TF_Status* s) {
-  auto* mlir_op = dyncast<MlirAbstractOp>(abstract_op);
-  if (mlir_op == nullptr) {
-    s->status = tensorflow::errors::InvalidArgument(
-        "Unable to cast AbstractOp to TF_GraphOp.");
-    return;
-  }
-  SmallVector<Value, 8> operands;
-  for (int i = 0; i < num_inputs; ++i) {
-    auto* operand = dyncast<MlirTensor>(inputs[i]);
-    if (!operand) {
-      s->status = tensorflow::errors::InvalidArgument(
-          "Capturing eager tensors is not supported yet.");
-      return;
-    }
-    if (operand->getValue().getContext() != context_.get()) {
-      s->status = tensorflow::errors::InvalidArgument(
-          "Capturing tensors from other context is not supported.");
-      return;
-    }
-    operands.push_back(operand->getValue());
-  }
-  OperationState* state = mlir_op->Create(operands, s);
-  if (!s->status.ok() || !state) return;
-  Operation* op = builder_.createOperation(*state);
-  int num_results = op->getNumResults();
-  o->outputs.clear();
-  o->outputs.reserve(num_results);
-  for (Value result : op->getResults())
-    o->outputs.push_back(new MlirTensor(result));
+Status MlirAbstractOp::Execute(absl::Span<AbstractTensorHandle*> retvals,
+                               int* num_retvals) {
+  OperationState* state;
+  TF_RETURN_IF_ERROR(Create(operands_, &state));
+  Operation* op = function_context_->CreateOperationFromState(*state);
+  *num_retvals = op->getNumResults();
+  for (int i = 0; i < *num_retvals; i++)
+    retvals[i] = new MlirTensor(op->getResult(i));
+  return Status::OK();
 }
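Usage sketch of the Status-based op interface shown above, assuming the caller sits inside a Status-returning helper; `op`, `lhs`, and `rhs` are illustrative names for a traced op and its input handles and are not defined in this diff.

  std::vector<AbstractTensorHandle*> retvals(1);
  int num_retvals = 1;
  // AddInput (defined further below) records each operand as an MLIR Value.
  TF_RETURN_IF_ERROR(op->AddInput(lhs));
  TF_RETURN_IF_ERROR(op->AddInput(rhs));
  TF_RETURN_IF_ERROR(op->Execute(absl::MakeSpan(retvals), &num_retvals));
  // retvals[0] now wraps the mlir::Value produced by the traced operation.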
 
-AbstractTensor* MlirFunctionContext::AddParameter(TF_DataType dtype,
-                                                  TF_Status* s) {
+Operation* MlirFunctionContext::CreateOperationFromState(
+    const OperationState& state) {
+  return builder_.createOperation(state);
+}
+
+Status MlirFunctionContext::AddParameter(tensorflow::DataType dtype,
+                                         TracingTensorHandle** handle) {
   Type type;
-  s->status = ConvertDataTypeToTensor(static_cast<tensorflow::DataType>(dtype),
-                                      builder_, &type);
-  if (!s->status.ok()) return nullptr;
-  return new MlirTensor(func_.getBody().front().addArgument(type));
+  TF_RETURN_IF_ERROR(ConvertDataTypeToTensor(dtype, builder_, &type));
+  *handle = new MlirTensor(func_.getBody().front().addArgument(type));
+  return Status::OK();
 }
 
-AbstractFunction* MlirFunctionContext::Finalize(OutputList* outputs,
-                                                TF_Status* s) {
+Status MlirAbstractOp::AddInput(AbstractTensorHandle* input) {
+  auto* operand = dyn_cast<MlirTensor>(input);
+  if (!operand) {
+    return tensorflow::errors::InvalidArgument(
+        "Unable to cast input to MlirTensor");
+  }
+  operands_.push_back(operand->getValue());
+  return Status::OK();
+}
+Status MlirFunctionContext::Finalize(OutputList* outputs,
+                                     AbstractFunction** f) {
   Block& body = func_.getBody().front();
   SmallVector<Value, 8> ret_operands;
-  for (AbstractTensor* output : outputs->outputs) {
-    auto* operand = dyncast<MlirTensor>(output);
+  for (auto* output : outputs->outputs) {
+    auto* operand = dyn_cast<MlirTensor>(output);
     if (!operand) {
-      s->status = tensorflow::errors::InvalidArgument(
+      return tensorflow::errors::InvalidArgument(
           "Capturing eager tensors is not supported yet.");
-      return nullptr;
     }
     if (operand->getValue().getContext() != context_.get()) {
-      s->status = tensorflow::errors::InvalidArgument(
+      return tensorflow::errors::InvalidArgument(
          "Capturing tensors from other context is not supported.");
-      return nullptr;
    }
    ret_operands.push_back(operand->getValue());
  }
@@ -478,16 +618,17 @@ AbstractFunction* MlirFunctionContext::Finalize(OutputList* outputs,
   auto result_types =
       llvm::to_vector<8>(body.getTerminator()->getOperandTypes());
   func_.setType(FunctionType::get(arg_types, result_types, func_.getContext()));
-  return new MlirFunction(std::move(context_), std::move(module_), func_);
+  *f = new MlirFunction(std::move(context_), std::move(module_), func_);
+  return Status::OK();
 }
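A hypothetical end-to-end tracing sketch tying the interfaces above together, again assuming a Status-returning helper; `ctx` and `result` are illustrative names, and the op that produces `result` is elided.

  // `ctx` is assumed to be an MlirFunctionContext obtained from MlirTracingFactory.
  TracingTensorHandle* arg = nullptr;
  TF_RETURN_IF_ERROR(ctx->AddParameter(tensorflow::DT_FLOAT, &arg));
  // ... trace ops that consume `arg` and produce an AbstractTensorHandle* `result` ...
  OutputList outputs;
  outputs.outputs.push_back(result);
  AbstractFunction* fn = nullptr;
  TF_RETURN_IF_ERROR(ctx->Finalize(&outputs, &fn));
  // `fn` can then be lowered to a FunctionDef via MlirFunction::GetFunctionDef.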
 
 extern "C" {
-ExecutionContext* MlirTracingFactory(const char* fn_name, TF_Status* s) {
+TracingContext* MlirTracingFactory(const char* fn_name, TF_Status* s) {
   RegisterDialects();
   return new MlirFunctionContext(fn_name);
 }
 }
 
-} // end anonymous namespace
-} // end namespace TF
-} // end namespace mlir
+} // namespace
+} // namespace TF
+} // namespace mlir
@@ -15,10 +15,10 @@ limitations under the License.
 
 #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
 
-using tensorflow::internal::ExecutionContext;
+using tensorflow::tracing::TracingContext;
 
 extern "C" {
-ExecutionContext* MlirTracingFactory(const char* fn_name, TF_Status* s);
+TracingContext* MlirTracingFactory(const char* fn_name, TF_Status* s);
 }
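A hedged usage sketch of the exported factory declared above: in practice the symbol is assumed to be registered with the unified C API rather than called by hand, but it can be invoked directly to obtain a tracing context.

  TF_Status* status = TF_NewStatus();
  TracingContext* tracing_ctx = MlirTracingFactory("my_traced_fn", status);
  if (TF_GetCode(status) != TF_OK || tracing_ctx == nullptr) {
    // Creation failed; inspect TF_Message(status).
  }
  TF_DeleteStatus(status);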
 
 namespace {