Improve documentation for the unified C API and perform some minor cleanups
- Document most of the APIs, - move the C++ implementation in the tensorflow::internal namespace, - separate more cleanly the C API entry points, - Drop the TF_ prefix from the C++ classes where it was missed during refactoring. - Use IWYU to cleanup the includes - Rename c_api_unified_experimental_private.h into c_api_unified_experimental_internal.h for align with the convention in the directory. PiperOrigin-RevId: 308699592 Change-Id: Id7198f73a63c400ce94bdeefeef010f441622a88
This commit is contained in:
parent
c6eaf562e7
commit
ea74b01d80
@ -369,7 +369,7 @@ tf_cuda_library(
|
|||||||
"c_api_unified_experimental.cc",
|
"c_api_unified_experimental.cc",
|
||||||
"c_api_unified_experimental_eager.cc",
|
"c_api_unified_experimental_eager.cc",
|
||||||
"c_api_unified_experimental_graph.cc",
|
"c_api_unified_experimental_graph.cc",
|
||||||
"c_api_unified_experimental_private.h",
|
"c_api_unified_experimental_internal.h",
|
||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"c_api_experimental.h",
|
"c_api_experimental.h",
|
||||||
@ -466,6 +466,7 @@ tf_cuda_cc_test(
|
|||||||
":c_api",
|
":c_api",
|
||||||
":c_api_experimental",
|
":c_api_experimental",
|
||||||
":c_api_test_util",
|
":c_api_test_util",
|
||||||
|
"//tensorflow/c:c_api",
|
||||||
"//tensorflow/c:c_test_util",
|
"//tensorflow/c:c_test_util",
|
||||||
"//tensorflow/cc/profiler",
|
"//tensorflow/cc/profiler",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
@ -13,24 +13,28 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "absl/types/variant.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||||
#include "tensorflow/c/c_api.h"
|
|
||||||
#include "tensorflow/c/eager/c_api.h"
|
#include <vector>
|
||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental_private.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
|
||||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
|
||||||
#include "tensorflow/core/platform/casts.h"
|
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
|
||||||
#include "tensorflow/core/platform/strcat.h"
|
|
||||||
|
|
||||||
using tensorflow::string;
|
using tensorflow::string;
|
||||||
using tensorflow::internal::OutputList;
|
using tensorflow::internal::OutputList;
|
||||||
using tensorflow::internal::unwrap;
|
using tensorflow::internal::unwrap;
|
||||||
using tensorflow::internal::wrap;
|
|
||||||
|
// =============================================================================
|
||||||
|
// Public C API entry points
|
||||||
|
//
|
||||||
|
// These are only the generic entry points for the C API. This file does not
|
||||||
|
// have any visibility into the graph/eager implementation and is only providing
|
||||||
|
// C bindings to the abstract classes defined in the
|
||||||
|
// c_api_unified_experimental_internal.h header.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }
|
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }
|
||||||
|
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
|
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
|
||||||
|
|
||||||
#include "tensorflow/c/eager/c_api.h"
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
@ -34,23 +35,34 @@ extern "C" {
|
|||||||
// E.g. it could know whether we're in eager mode or in graph mode, keeps track
|
// E.g. it could know whether we're in eager mode or in graph mode, keeps track
|
||||||
// of gradient tapes, etc.
|
// of gradient tapes, etc.
|
||||||
typedef struct TF_ExecutionContext TF_ExecutionContext;
|
typedef struct TF_ExecutionContext TF_ExecutionContext;
|
||||||
|
|
||||||
// A TF_AbstractTensor is an input to an operation. E.g. it could be a union
|
// A TF_AbstractTensor is an input to an operation. E.g. it could be a union
|
||||||
// type of eager and graph tensors.
|
// type of eager and graph tensors. It is also the result of executing an
|
||||||
|
// operation.
|
||||||
typedef struct TF_AbstractTensor TF_AbstractTensor;
|
typedef struct TF_AbstractTensor TF_AbstractTensor;
|
||||||
|
|
||||||
// A TF_AbstractOp is the metadata we need to execute an operation. E.g. this
|
// A TF_AbstractOp is the metadata we need to execute an operation. E.g. this
|
||||||
// could contain the op type and other attributes.
|
// could contain the op type and other attributes.
|
||||||
typedef struct TF_AbstractOp TF_AbstractOp;
|
typedef struct TF_AbstractOp TF_AbstractOp;
|
||||||
|
|
||||||
|
// Stores a function representation that can be used for execution or for
|
||||||
|
// setting functional attributes of other composite ops e.g. control flow.
|
||||||
|
typedef struct TF_AbstractFunction TF_AbstractFunction;
|
||||||
|
|
||||||
|
// Creates a context for tracing the execution of operations into a function.
|
||||||
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s);
|
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s);
|
||||||
|
|
||||||
|
// Creates a context for eager execution of operations.
|
||||||
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
|
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
|
||||||
TF_Status* s);
|
TF_Status* s);
|
||||||
|
|
||||||
void TF_DeleteExecutionContext(TF_ExecutionContext*);
|
void TF_DeleteExecutionContext(TF_ExecutionContext*);
|
||||||
|
|
||||||
|
// Create an operation suitable to use with the provided context. The operation
|
||||||
|
// requires its type (e.g. "AddV2") to be set independently.
|
||||||
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
|
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
|
||||||
void TF_DeleteAbstractOp(TF_AbstractOp*);
|
void TF_DeleteAbstractOp(TF_AbstractOp*);
|
||||||
|
|
||||||
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
|
|
||||||
// TODO(srbs): Add APIs for specifying attrs etc.
|
// TODO(srbs): Add APIs for specifying attrs etc.
|
||||||
// `op_type` must outlive `op`.
|
// `op_type` must outlive `op`.
|
||||||
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
|
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
|
||||||
@ -62,9 +74,16 @@ void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
|
|||||||
void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
|
void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
|
||||||
TF_DataType value, TF_Status* s);
|
TF_DataType value, TF_Status* s);
|
||||||
|
|
||||||
// TF_OutputList just lets us not specify the number of outputs of an operation
|
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
|
||||||
|
|
||||||
|
// TF_OutputList holds the list of TF_AbstractTensor that results from executing
|
||||||
|
// an operation.
|
||||||
|
// It just lets us not specify the number of outputs of an operation
|
||||||
// beforehand. This forces a memory allocation in the runtime, which is bad, but
|
// beforehand. This forces a memory allocation in the runtime, which is bad, but
|
||||||
// it allows for generic code.
|
// it allows for generic code.
|
||||||
|
// TODO(aminim): the description above isn't clear with respect to
|
||||||
|
// TF_OutputListNumOutputs and the current eager implementation which requires
|
||||||
|
// the number of outputs to be set by the client.
|
||||||
typedef struct TF_OutputList TF_OutputList;
|
typedef struct TF_OutputList TF_OutputList;
|
||||||
TF_OutputList* TF_NewOutputList();
|
TF_OutputList* TF_NewOutputList();
|
||||||
void TF_DeleteOutputList(TF_OutputList* o);
|
void TF_DeleteOutputList(TF_OutputList* o);
|
||||||
@ -72,27 +91,32 @@ void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
|
|||||||
int TF_OutputListNumOutputs(TF_OutputList* o);
|
int TF_OutputListNumOutputs(TF_OutputList* o);
|
||||||
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
|
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
|
||||||
|
|
||||||
// Stores a function representation that can be used for execution or for
|
|
||||||
// setting functional attributes of other composite ops e.g. control flow.
|
|
||||||
typedef struct TF_AbstractFunction TF_AbstractFunction;
|
|
||||||
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
|
||||||
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
|
||||||
const TF_AbstractTensor* inputs, int num_outputs,
|
|
||||||
const TF_AbstractTensor* outputs, TF_Status* status);
|
|
||||||
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
|
|
||||||
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext*,
|
|
||||||
TF_AbstractFunction*, TF_Status*);
|
|
||||||
|
|
||||||
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
|
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
|
||||||
// capture some inputs and then add a node in the graph, and after
|
// capture some inputs and then add a node in the graph. The output tensors are
|
||||||
// execution/node creation it'll go and record things that happened in any tape
|
// returned through the provided TF_OutputList.
|
||||||
// which happens to be active.
|
// Any active tape will observe the effects of this execution.
|
||||||
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
|
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
|
||||||
TF_AbstractTensor* const* inputs, TF_OutputList* o,
|
TF_AbstractTensor* const* inputs, TF_OutputList* o,
|
||||||
TF_ExecutionContext* ctx, TF_Status* s);
|
TF_ExecutionContext* ctx, TF_Status* s);
|
||||||
|
|
||||||
|
// Creates a new TF_AbstractFunction from the current tracing states in the
|
||||||
|
// context. The returned TF_GraphToFunction must be deleted by the client.
|
||||||
|
// TODO(aminim): clarify the contract on the state of the context after this
|
||||||
|
// call.
|
||||||
|
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
||||||
|
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||||
|
const TF_AbstractTensor* inputs, int num_outputs,
|
||||||
|
const TF_AbstractTensor* outputs, TF_Status* status);
|
||||||
|
|
||||||
|
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
|
||||||
|
|
||||||
|
// Register the function with the given context. This is particularly useful for
|
||||||
|
// making a function available to an eager context.
|
||||||
|
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext*,
|
||||||
|
TF_AbstractFunction*, TF_Status*);
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
// APIs specific to Eager and graph modes
|
// APIs specific to Eager modes
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
// Temporary APIs till we figure out how to create scalar valued Eager
|
// Temporary APIs till we figure out how to create scalar valued Eager
|
||||||
|
@ -13,32 +13,24 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "absl/types/variant.h"
|
#include <vector>
|
||||||
#include "tensorflow/c/c_api.h"
|
|
||||||
#include "tensorflow/c/eager/c_api.h"
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental_private.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
|
||||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
|
||||||
#include "tensorflow/core/platform/casts.h"
|
#include "tensorflow/core/platform/casts.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
|
||||||
#include "tensorflow/core/platform/strcat.h"
|
#include "tensorflow/core/platform/strcat.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
using tensorflow::string;
|
using tensorflow::string;
|
||||||
using tensorflow::internal::AbstractFunction;
|
|
||||||
using tensorflow::internal::AbstractOp;
|
|
||||||
using tensorflow::internal::AbstractTensor;
|
|
||||||
using tensorflow::internal::dynamic_cast_helper;
|
|
||||||
using tensorflow::internal::ExecutionContext;
|
|
||||||
using tensorflow::internal::OutputList;
|
|
||||||
using tensorflow::internal::unwrap;
|
|
||||||
using tensorflow::internal::wrap;
|
|
||||||
|
|
||||||
class TF_EagerContext;
|
namespace tensorflow {
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
// Simple wrapper over a TFE_TensorHandle
|
||||||
struct EagerTensor : public AbstractTensor {
|
struct EagerTensor : public AbstractTensor {
|
||||||
TFE_TensorHandle* t = nullptr;
|
TFE_TensorHandle* t = nullptr;
|
||||||
EagerTensor() : AbstractTensor(kKind) {}
|
EagerTensor() : AbstractTensor(kKind) {}
|
||||||
@ -47,9 +39,10 @@ struct EagerTensor : public AbstractTensor {
|
|||||||
static constexpr AbstractTensorKind kKind = kEagerTensor;
|
static constexpr AbstractTensorKind kKind = kEagerTensor;
|
||||||
};
|
};
|
||||||
|
|
||||||
class TF_EagerOp : public AbstractOp {
|
// Simple wrapper over a TFE_Op
|
||||||
|
class EagerOp : public AbstractOp {
|
||||||
public:
|
public:
|
||||||
explicit TF_EagerOp(TFE_Context* ctx) : AbstractOp(kKind), ctx_(ctx) {}
|
explicit EagerOp(TFE_Context* ctx) : AbstractOp(kKind), ctx_(ctx) {}
|
||||||
void SetOpType(const char* const op_type, TF_Status* s) override {
|
void SetOpType(const char* const op_type, TF_Status* s) override {
|
||||||
op_ = TFE_NewOp(ctx_, op_type, s);
|
op_ = TFE_NewOp(ctx_, op_type, s);
|
||||||
}
|
}
|
||||||
@ -66,18 +59,19 @@ class TF_EagerOp : public AbstractOp {
|
|||||||
TFE_OpSetAttrType(op_, attr_name, value);
|
TFE_OpSetAttrType(op_, attr_name, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
~TF_EagerOp() override { TFE_DeleteOp(op_); }
|
~EagerOp() override { TFE_DeleteOp(op_); }
|
||||||
static constexpr AbstractOpKind kKind = kEagerOp;
|
static constexpr AbstractOpKind kKind = kEagerOp;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class TF_EagerContext; // For access to op_.
|
friend class EagerContext; // For access to op_.
|
||||||
TFE_Op* op_ = nullptr;
|
TFE_Op* op_ = nullptr;
|
||||||
TFE_Context* ctx_;
|
TFE_Context* ctx_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class TF_EagerContext : public ExecutionContext {
|
// Wraps a TFE_Context and dispatch EagerOp with EagerTensor inputs.
|
||||||
|
class EagerContext : public ExecutionContext {
|
||||||
public:
|
public:
|
||||||
TF_EagerContext() : ExecutionContext(kKind) {}
|
EagerContext() : ExecutionContext(kKind) {}
|
||||||
|
|
||||||
void Build(TFE_ContextOptions* options, TF_Status* status) {
|
void Build(TFE_ContextOptions* options, TF_Status* status) {
|
||||||
eager_ctx_ = TFE_NewContext(options, status);
|
eager_ctx_ = TFE_NewContext(options, status);
|
||||||
@ -85,13 +79,13 @@ class TF_EagerContext : public ExecutionContext {
|
|||||||
|
|
||||||
AbstractOp* CreateOperation() override {
|
AbstractOp* CreateOperation() override {
|
||||||
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
||||||
return new TF_EagerOp(eager_ctx_);
|
return new EagerOp(eager_ctx_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ExecuteOperation(AbstractOp* op, int num_inputs,
|
void ExecuteOperation(AbstractOp* op, int num_inputs,
|
||||||
AbstractTensor* const* inputs, OutputList* o,
|
AbstractTensor* const* inputs, OutputList* o,
|
||||||
TF_Status* s) override {
|
TF_Status* s) override {
|
||||||
auto* eager_op = dynamic_cast_helper<TF_EagerOp>(op);
|
auto* eager_op = dyncast<EagerOp>(op);
|
||||||
if (eager_op == nullptr) {
|
if (eager_op == nullptr) {
|
||||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||||
"Unable to cast AbstractOp to TF_EagerOp.");
|
"Unable to cast AbstractOp to TF_EagerOp.");
|
||||||
@ -100,7 +94,7 @@ class TF_EagerContext : public ExecutionContext {
|
|||||||
auto* tfe_op = eager_op->op_;
|
auto* tfe_op = eager_op->op_;
|
||||||
if (TF_GetCode(s) != TF_OK) return;
|
if (TF_GetCode(s) != TF_OK) return;
|
||||||
for (int i = 0; i < num_inputs; ++i) {
|
for (int i = 0; i < num_inputs; ++i) {
|
||||||
auto* eager_tensor = dynamic_cast_helper<const EagerTensor>(inputs[i]);
|
auto* eager_tensor = dyncast<const EagerTensor>(inputs[i]);
|
||||||
if (!eager_tensor) {
|
if (!eager_tensor) {
|
||||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
|
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
|
||||||
return;
|
return;
|
||||||
@ -137,34 +131,45 @@ class TF_EagerContext : public ExecutionContext {
|
|||||||
TFE_ContextAddFunction(eager_ctx_, func, s);
|
TFE_ContextAddFunction(eager_ctx_, func, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
~TF_EagerContext() override { TFE_DeleteContext(eager_ctx_); }
|
~EagerContext() override { TFE_DeleteContext(eager_ctx_); }
|
||||||
|
|
||||||
static constexpr ExecutionContextKind kKind = kEagerContext;
|
static constexpr ExecutionContextKind kKind = kEagerContext;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend TFE_Context* TF_ExecutionContextGetTFEContext(
|
friend TFE_Context* ::TF_ExecutionContextGetTFEContext(
|
||||||
TF_ExecutionContext* ctx);
|
TF_ExecutionContext* ctx);
|
||||||
TFE_Context* eager_ctx_;
|
TFE_Context* eager_ctx_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
} // namespace internal
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Public C API entry points
|
||||||
|
// These are only the entry points specific to the Eager API.
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
using tensorflow::internal::dyncast;
|
||||||
|
using tensorflow::internal::unwrap;
|
||||||
|
|
||||||
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions* options,
|
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions* options,
|
||||||
TF_Status* s) {
|
TF_Status* s) {
|
||||||
auto* ctx = new TF_EagerContext();
|
auto* ctx = new tensorflow::internal::EagerContext();
|
||||||
ctx->Build(options, s);
|
ctx->Build(options, s);
|
||||||
return wrap(ctx);
|
return wrap(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
|
TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
|
||||||
TF_Status* s) {
|
TF_Status* s) {
|
||||||
return wrap(new EagerTensor(t));
|
return wrap(new tensorflow::internal::EagerTensor(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
|
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
|
||||||
TF_Status* s) {
|
TF_Status* s) {
|
||||||
auto* eager_tensor = dynamic_cast_helper<EagerTensor>(unwrap(at));
|
auto* eager_tensor = dyncast<tensorflow::internal::EagerTensor>(unwrap(at));
|
||||||
if (!eager_tensor) {
|
if (!eager_tensor) {
|
||||||
string msg = absl::StrCat("Not an eager tensor handle.",
|
string msg = tensorflow::strings::StrCat("Not an eager tensor handle.",
|
||||||
reinterpret_cast<uintptr_t>(at));
|
reinterpret_cast<uintptr_t>(at));
|
||||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
@ -172,5 +177,7 @@ TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
|
|||||||
}
|
}
|
||||||
|
|
||||||
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) {
|
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) {
|
||||||
return dynamic_cast_helper<TF_EagerContext>(unwrap(ctx))->eager_ctx_;
|
auto* eager_ctx = dyncast<tensorflow::internal::EagerContext>(unwrap(ctx));
|
||||||
|
if (!eager_ctx) return nullptr;
|
||||||
|
return eager_ctx->eager_ctx_;
|
||||||
}
|
}
|
||||||
|
@ -13,49 +13,45 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "absl/types/variant.h"
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
#include "tensorflow/c/eager/c_api.h"
|
|
||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental_private.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
|
||||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
|
||||||
#include "tensorflow/core/platform/casts.h"
|
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
|
||||||
#include "tensorflow/core/platform/strcat.h"
|
#include "tensorflow/core/platform/strcat.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
using tensorflow::string;
|
using tensorflow::string;
|
||||||
using tensorflow::internal::AbstractFunction;
|
|
||||||
using tensorflow::internal::AbstractOp;
|
|
||||||
using tensorflow::internal::AbstractTensor;
|
|
||||||
using tensorflow::internal::dynamic_cast_helper;
|
|
||||||
using tensorflow::internal::ExecutionContext;
|
|
||||||
using tensorflow::internal::OutputList;
|
|
||||||
using tensorflow::internal::unwrap;
|
|
||||||
using tensorflow::internal::wrap;
|
|
||||||
|
|
||||||
class TF_GraphContext;
|
namespace tensorflow {
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
class GraphContext;
|
||||||
|
|
||||||
|
// GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
|
||||||
|
// into the list of outputs for the operation.
|
||||||
struct GraphTensor : public AbstractTensor {
|
struct GraphTensor : public AbstractTensor {
|
||||||
TF_Output output{};
|
TF_Output output{};
|
||||||
TF_GraphContext* ctx = nullptr;
|
GraphContext* ctx = nullptr;
|
||||||
GraphTensor() : AbstractTensor(kKind) {}
|
GraphTensor() : AbstractTensor(kKind) {}
|
||||||
GraphTensor(TF_Output output, TF_GraphContext* ctx)
|
GraphTensor(TF_Output output, GraphContext* ctx)
|
||||||
: AbstractTensor(kKind), output(output), ctx(ctx) {}
|
: AbstractTensor(kKind), output(output), ctx(ctx) {}
|
||||||
static constexpr AbstractTensorKind kKind = kGraphTensor;
|
static constexpr AbstractTensorKind kKind = kGraphTensor;
|
||||||
};
|
};
|
||||||
|
|
||||||
class TF_GraphOp : public AbstractOp {
|
// GraphOp wraps and populate a TF_OperationDescription.
|
||||||
|
class GraphOp : public AbstractOp {
|
||||||
public:
|
public:
|
||||||
explicit TF_GraphOp(TF_Graph* g) : AbstractOp(kKind), g_(g) {}
|
explicit GraphOp(TF_Graph* g) : AbstractOp(kKind), g_(g) {}
|
||||||
void SetOpType(const char* const op_type, TF_Status* s) override {
|
void SetOpType(const char* const op_type, TF_Status* s) override {
|
||||||
if (op_) {
|
if (op_) {
|
||||||
TF_SetStatus(
|
TF_SetStatus(
|
||||||
s, TF_FAILED_PRECONDITION,
|
s, TF_FAILED_PRECONDITION,
|
||||||
absl::StrCat("SetOpType called on already built op.").c_str());
|
strings::StrCat("SetOpType called on already built op.").c_str());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (op_name_ != nullptr) {
|
if (op_name_ != nullptr) {
|
||||||
@ -69,7 +65,7 @@ class TF_GraphOp : public AbstractOp {
|
|||||||
if (op_) {
|
if (op_) {
|
||||||
TF_SetStatus(
|
TF_SetStatus(
|
||||||
s, TF_FAILED_PRECONDITION,
|
s, TF_FAILED_PRECONDITION,
|
||||||
absl::StrCat("SetOpName called on already built op.").c_str());
|
strings::StrCat("SetOpName called on already built op.").c_str());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (op_type_ != nullptr) {
|
if (op_type_ != nullptr) {
|
||||||
@ -89,12 +85,12 @@ class TF_GraphOp : public AbstractOp {
|
|||||||
}
|
}
|
||||||
TF_SetAttrType(op_.get(), attr_name, value);
|
TF_SetAttrType(op_.get(), attr_name, value);
|
||||||
}
|
}
|
||||||
~TF_GraphOp() override {}
|
~GraphOp() override {}
|
||||||
|
|
||||||
static constexpr AbstractOpKind kKind = kGraphOp;
|
static constexpr AbstractOpKind kKind = kGraphOp;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class TF_GraphContext; // For access to op_.
|
friend class GraphContext; // For access to op_.
|
||||||
TF_Graph* g_;
|
TF_Graph* g_;
|
||||||
std::unique_ptr<TF_OperationDescription> op_;
|
std::unique_ptr<TF_OperationDescription> op_;
|
||||||
// Hold `op_type` and `op_name` till both are available since we need both
|
// Hold `op_type` and `op_name` till both are available since we need both
|
||||||
@ -103,6 +99,7 @@ class TF_GraphOp : public AbstractOp {
|
|||||||
const char* op_name_ = nullptr;
|
const char* op_name_ = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// GraphFunction is a thin wrapper over a TF_Function.
|
||||||
struct GraphFunction : public AbstractFunction {
|
struct GraphFunction : public AbstractFunction {
|
||||||
TF_Function* func = nullptr;
|
TF_Function* func = nullptr;
|
||||||
GraphFunction() : AbstractFunction(kKind) {}
|
GraphFunction() : AbstractFunction(kKind) {}
|
||||||
@ -117,20 +114,22 @@ struct GraphFunction : public AbstractFunction {
|
|||||||
static constexpr AbstractFunctionKind kKind = kGraphFunc;
|
static constexpr AbstractFunctionKind kKind = kGraphFunc;
|
||||||
};
|
};
|
||||||
|
|
||||||
class TF_GraphContext : public ExecutionContext {
|
// GraphContext wraps a TF_Graph and manages the "execution" of operation, i.e.
|
||||||
|
// adding them to the graph.
|
||||||
|
class GraphContext : public ExecutionContext {
|
||||||
public:
|
public:
|
||||||
TF_GraphContext()
|
GraphContext()
|
||||||
: ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
|
: ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
|
||||||
|
|
||||||
AbstractOp* CreateOperation() override {
|
AbstractOp* CreateOperation() override {
|
||||||
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
||||||
return new TF_GraphOp(graph_.get());
|
return new GraphOp(graph_.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
void ExecuteOperation(AbstractOp* op, int num_inputs,
|
void ExecuteOperation(AbstractOp* op, int num_inputs,
|
||||||
AbstractTensor* const* inputs, OutputList* o,
|
AbstractTensor* const* inputs, OutputList* o,
|
||||||
TF_Status* s) override {
|
TF_Status* s) override {
|
||||||
auto* graph_op = dynamic_cast_helper<TF_GraphOp>(op);
|
auto* graph_op = dyncast<GraphOp>(op);
|
||||||
if (graph_op == nullptr) {
|
if (graph_op == nullptr) {
|
||||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||||
"Unable to cast AbstractOp to TF_GraphOp.");
|
"Unable to cast AbstractOp to TF_GraphOp.");
|
||||||
@ -138,7 +137,7 @@ class TF_GraphContext : public ExecutionContext {
|
|||||||
}
|
}
|
||||||
auto* tf_opdesc = graph_op->op_.release();
|
auto* tf_opdesc = graph_op->op_.release();
|
||||||
for (int i = 0; i < num_inputs; ++i) {
|
for (int i = 0; i < num_inputs; ++i) {
|
||||||
auto* graph_tensor = dynamic_cast_helper<GraphTensor>(inputs[i]);
|
auto* graph_tensor = dyncast<GraphTensor>(inputs[i]);
|
||||||
if (!graph_tensor) {
|
if (!graph_tensor) {
|
||||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||||
"Capturing eager tensors is not supported yet.");
|
"Capturing eager tensors is not supported yet.");
|
||||||
@ -190,7 +189,7 @@ class TF_GraphContext : public ExecutionContext {
|
|||||||
"Registering graph functions has not been implemented yet.");
|
"Registering graph functions has not been implemented yet.");
|
||||||
}
|
}
|
||||||
|
|
||||||
~TF_GraphContext() override {}
|
~GraphContext() override {}
|
||||||
|
|
||||||
static constexpr ExecutionContextKind kKind = kGraphContext;
|
static constexpr ExecutionContextKind kKind = kGraphContext;
|
||||||
|
|
||||||
@ -198,26 +197,23 @@ class TF_GraphContext : public ExecutionContext {
|
|||||||
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
|
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
|
||||||
};
|
};
|
||||||
|
|
||||||
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s) {
|
// Helper that converts the graph currently held in the context into a function.
|
||||||
return wrap(new TF_GraphContext());
|
static AbstractFunction* ExecutionContextToFunction(
|
||||||
}
|
const ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||||
|
const AbstractTensor* inputs, int num_outputs,
|
||||||
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
const AbstractTensor* outputs, TF_Status* status) {
|
||||||
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
auto* graph_ctx = dyncast<const GraphContext>(fn_body);
|
||||||
const TF_AbstractTensor* inputs, int num_outputs,
|
|
||||||
const TF_AbstractTensor* outputs, TF_Status* status) {
|
|
||||||
auto* graph_ctx = dynamic_cast_helper<const TF_GraphContext>(unwrap(fn_body));
|
|
||||||
if (graph_ctx == nullptr) {
|
if (graph_ctx == nullptr) {
|
||||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||||
"fn_body is not a TF_GraphContext.");
|
"fn_body is not a TF_GraphContext.");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto* graph_inputs = dynamic_cast_helper<const GraphTensor>(unwrap(inputs));
|
auto* graph_inputs = dyncast<const GraphTensor>(inputs);
|
||||||
if (!graph_inputs) {
|
if (!graph_inputs) {
|
||||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "inputs aren't GraphTensors.");
|
TF_SetStatus(status, TF_INVALID_ARGUMENT, "inputs aren't GraphTensors.");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto* graph_outputs = dynamic_cast_helper<const GraphTensor>(unwrap(outputs));
|
auto* graph_outputs = dyncast<const GraphTensor>(outputs);
|
||||||
if (!graph_outputs) {
|
if (!graph_outputs) {
|
||||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "outputs aren't GraphTensors.");
|
TF_SetStatus(status, TF_INVALID_ARGUMENT, "outputs aren't GraphTensors.");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -225,5 +221,28 @@ TF_AbstractFunction* TF_ExecutionContextToFunction(
|
|||||||
GraphFunction* func = new GraphFunction;
|
GraphFunction* func = new GraphFunction;
|
||||||
func->func = graph_ctx->ToFunction(fn_name, num_inputs, graph_inputs,
|
func->func = graph_ctx->ToFunction(fn_name, num_inputs, graph_inputs,
|
||||||
num_outputs, graph_outputs, status);
|
num_outputs, graph_outputs, status);
|
||||||
return wrap(func);
|
return func;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace internal
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Public C API entry points
|
||||||
|
// These are only the entry points specific to the Graph API.
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
using tensorflow::internal::unwrap;
|
||||||
|
|
||||||
|
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s) {
|
||||||
|
return wrap(new tensorflow::internal::GraphContext());
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
||||||
|
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||||
|
const TF_AbstractTensor* inputs, int num_outputs,
|
||||||
|
const TF_AbstractTensor* outputs, TF_Status* status) {
|
||||||
|
return wrap(ExecutionContextToFunction(unwrap(fn_body), fn_name, num_inputs,
|
||||||
|
unwrap(inputs), num_outputs,
|
||||||
|
unwrap(outputs), status));
|
||||||
}
|
}
|
||||||
|
184
tensorflow/c/eager/c_api_unified_experimental_internal.h
Normal file
184
tensorflow/c/eager/c_api_unified_experimental_internal.h
Normal file
@ -0,0 +1,184 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
|
||||||
|
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||||
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Implementation detail for the unified execution APIs for Eager and tracing
|
||||||
|
// backends (graph/MLIR).
|
||||||
|
//
|
||||||
|
// This defines a set of abstract classes that are intended to provide the
|
||||||
|
// functionality of the opaque C types exposed in the public APIs defined in the
|
||||||
|
// `c_api_unified_experimental.h` header.
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
// We can't depend on C++ rtti, but we still want to be able to have a safe
|
||||||
|
// dynamic_cast to provide diagnostics to the user when the API is misused.
|
||||||
|
// Instead we model RTTI by listing all the possible subclasses for each
|
||||||
|
// abstract base. Each subclass initializes the base class with the right
|
||||||
|
// `kind`, which allows an equivalent to `std::dynamic_cast` provided by this
|
||||||
|
// utility.
|
||||||
|
template <typename T, typename S>
|
||||||
|
T* dyncast(S source) {
|
||||||
|
if (source->getKind() != T::kKind) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return tensorflow::down_cast<T*>(source);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Represents either an EagerTensor or a GraphTensor.
|
||||||
|
// This base class does not expose any public methods other than to distinguish
|
||||||
|
// which subclass it actually is. The user is responsible to use the right
|
||||||
|
// type of AbstractTensor in their context (do not pass an EagerTensor to a
|
||||||
|
// GraphContext and vice-versa).
|
||||||
|
class AbstractTensor {
|
||||||
|
protected:
|
||||||
|
enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor };
|
||||||
|
explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
|
||||||
|
|
||||||
|
public:
|
||||||
|
// Returns which subclass is this instance of.
|
||||||
|
AbstractTensorKind getKind() const { return kind_; }
|
||||||
|
virtual ~AbstractTensor() = default;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const AbstractTensorKind kind_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Represents the results of the execution of an operation.
|
||||||
|
struct OutputList {
|
||||||
|
std::vector<AbstractTensor*> outputs;
|
||||||
|
int expected_num_outputs = -1;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Holds the result of tracing a function.
|
||||||
|
class AbstractFunction {
|
||||||
|
protected:
|
||||||
|
enum AbstractFunctionKind { kGraphFunc };
|
||||||
|
explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {}
|
||||||
|
|
||||||
|
public:
|
||||||
|
// Returns which subclass is this instance of.
|
||||||
|
AbstractFunctionKind getKind() const { return kind_; }
|
||||||
|
virtual ~AbstractFunction() = default;
|
||||||
|
|
||||||
|
// Temporary API till we figure the right abstraction for AbstractFunction.
|
||||||
|
// At the moment both Eager and Graph needs access to a "TF_Function" object.
|
||||||
|
virtual TF_Function* GetTfFunction(TF_Status* s) = 0;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const AbstractFunctionKind kind_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// An abstract operation describes an operation by its type, name, and
|
||||||
|
// attributes. It can be "executed" by the context with some input tensors.
|
||||||
|
// It is allowed to reusing the same abstract operation for multiple execution
|
||||||
|
// on a given context, with the same or different input tensors.
|
||||||
|
class AbstractOp {
|
||||||
|
protected:
|
||||||
|
enum AbstractOpKind { kGraphOp, kEagerOp };
|
||||||
|
explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
|
||||||
|
|
||||||
|
public:
|
||||||
|
// Returns which subclass is this instance of.
|
||||||
|
AbstractOpKind getKind() const { return kind_; }
|
||||||
|
virtual ~AbstractOp() = default;
|
||||||
|
|
||||||
|
// Sets the type of the operation (for example `AddV2`).
|
||||||
|
virtual void SetOpType(const char* op_type, TF_Status* s) = 0;
|
||||||
|
|
||||||
|
// Sets the name of the operation: this is an optional identifier that is
|
||||||
|
// not intended to carry semantics and preserved/propagated without
|
||||||
|
// guarantees.
|
||||||
|
virtual void SetOpName(const char* op_name, TF_Status* s) = 0;
|
||||||
|
|
||||||
|
// Add a `TypeAttribute` on the operation.
|
||||||
|
virtual void SetAttrType(const char* attr_name, TF_DataType value,
|
||||||
|
TF_Status* s) = 0;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const AbstractOpKind kind_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// This holds the context for the execution: dispatching operations either to an
|
||||||
|
// eager implementation or to a graph implementation.
|
||||||
|
struct ExecutionContext {
|
||||||
|
protected:
|
||||||
|
enum ExecutionContextKind { kGraphContext, kEagerContext };
|
||||||
|
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
|
||||||
|
|
||||||
|
public:
|
||||||
|
// Returns which subclass is this instance of.
|
||||||
|
ExecutionContextKind getKind() const { return k; }
|
||||||
|
virtual ~ExecutionContext() = default;
|
||||||
|
|
||||||
|
// Executes the operation on the provided inputs and populate the OutputList
|
||||||
|
// with the results. The input tensors must match the current context.
|
||||||
|
// The effect of "executing" an operation depends on the context: in an Eager
|
||||||
|
// context it will dispatch it to the runtime for execution, while in a
|
||||||
|
// tracing context it will add the operation to the current function.
|
||||||
|
virtual void ExecuteOperation(AbstractOp* op, int num_inputs,
|
||||||
|
AbstractTensor* const* inputs, OutputList* o,
|
||||||
|
TF_Status* s) = 0;
|
||||||
|
|
||||||
|
// Creates an empty AbstractOperation suitable to use with this context.
|
||||||
|
virtual AbstractOp* CreateOperation() = 0;
|
||||||
|
|
||||||
|
// Registers a functions with this context, after this the function is
|
||||||
|
// available to be called/referenced by its name in this context.
|
||||||
|
virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const ExecutionContextKind k;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create utilities to wrap/unwrap: this convert from the C opaque types to the
|
||||||
|
// C++ implementation, and back.
|
||||||
|
#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \
|
||||||
|
static inline CPP_CLASS* const& unwrap(C_TYPEDEF* const& o) { \
|
||||||
|
return reinterpret_cast<CPP_CLASS* const&>(o); \
|
||||||
|
} \
|
||||||
|
static inline const CPP_CLASS* const& unwrap(const C_TYPEDEF* const& o) { \
|
||||||
|
return reinterpret_cast<const CPP_CLASS* const&>(o); \
|
||||||
|
} \
|
||||||
|
static inline C_TYPEDEF* const& wrap(CPP_CLASS* const& o) { \
|
||||||
|
return reinterpret_cast<C_TYPEDEF* const&>(o); \
|
||||||
|
} \
|
||||||
|
static inline const C_TYPEDEF* const& wrap(const CPP_CLASS* const& o) { \
|
||||||
|
return reinterpret_cast<const C_TYPEDEF* const&>(o); \
|
||||||
|
}
|
||||||
|
|
||||||
|
MAKE_WRAP_UNWRAP(TF_ExecutionContext, ExecutionContext)
|
||||||
|
MAKE_WRAP_UNWRAP(TF_AbstractFunction, AbstractFunction)
|
||||||
|
MAKE_WRAP_UNWRAP(TF_AbstractTensor, AbstractTensor)
|
||||||
|
MAKE_WRAP_UNWRAP(TF_AbstractOp, AbstractOp)
|
||||||
|
MAKE_WRAP_UNWRAP(TF_OutputList, OutputList)
|
||||||
|
|
||||||
|
} // namespace internal
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
|
@ -1,126 +0,0 @@
|
|||||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_PRIVATE_H_
|
|
||||||
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_PRIVATE_H_
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
|
||||||
#include "tensorflow/core/platform/casts.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
namespace internal {
|
|
||||||
|
|
||||||
// =============================================================================
|
|
||||||
// Unified Execution APIs for Eager and tracing backends.
|
|
||||||
// =============================================================================
|
|
||||||
|
|
||||||
struct AbstractTensor {
|
|
||||||
enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor };
|
|
||||||
explicit AbstractTensor(AbstractTensorKind kind) : k(kind) {}
|
|
||||||
AbstractTensorKind getKind() const { return k; }
|
|
||||||
virtual ~AbstractTensor() = default;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const AbstractTensorKind k;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct OutputList {
|
|
||||||
std::vector<AbstractTensor*> outputs;
|
|
||||||
int expected_num_outputs = -1;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct AbstractFunction {
|
|
||||||
enum AbstractFunctionKind { kGraphFunc };
|
|
||||||
explicit AbstractFunction(AbstractFunctionKind kind) : k(kind) {}
|
|
||||||
AbstractFunctionKind getKind() const { return k; }
|
|
||||||
virtual ~AbstractFunction() = default;
|
|
||||||
|
|
||||||
// Temporary API till we figure the right abstraction for AbstractFunction
|
|
||||||
virtual TF_Function* GetTfFunction(TF_Status* s) = 0;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const AbstractFunctionKind k;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct AbstractOp {
|
|
||||||
// Needed to implement our own version of RTTI since dynamic_cast is not
|
|
||||||
// supported in mobile builds.
|
|
||||||
enum AbstractOpKind { kGraphOp, kEagerOp };
|
|
||||||
explicit AbstractOp(AbstractOpKind kind) : k(kind) {}
|
|
||||||
AbstractOpKind getKind() const { return k; }
|
|
||||||
virtual void SetOpType(const char* const op_type, TF_Status* s) = 0;
|
|
||||||
virtual void SetOpName(const char* const op_name, TF_Status* s) = 0;
|
|
||||||
virtual void SetAttrType(const char* const attr_name, TF_DataType value,
|
|
||||||
TF_Status* s) = 0;
|
|
||||||
virtual ~AbstractOp() {}
|
|
||||||
|
|
||||||
private:
|
|
||||||
const AbstractOpKind k;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct ExecutionContext {
|
|
||||||
// Needed to implement our own version of RTTI since dynamic_cast is not
|
|
||||||
// supported in mobile builds.
|
|
||||||
enum ExecutionContextKind { kGraphContext, kEagerContext };
|
|
||||||
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
|
|
||||||
ExecutionContextKind getKind() const { return k; }
|
|
||||||
|
|
||||||
virtual void ExecuteOperation(AbstractOp* op, int num_inputs,
|
|
||||||
AbstractTensor* const* inputs, OutputList* o,
|
|
||||||
TF_Status* s) = 0;
|
|
||||||
virtual AbstractOp* CreateOperation() = 0;
|
|
||||||
virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
|
|
||||||
virtual ~ExecutionContext() = default;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const ExecutionContextKind k;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Create utilities to wrap/unwrap: this convert from the C opaque types to the
|
|
||||||
// C++ implementation, and back.
|
|
||||||
#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \
|
|
||||||
static inline CPP_CLASS* const& unwrap(C_TYPEDEF* const& o) { \
|
|
||||||
return reinterpret_cast<CPP_CLASS* const&>(o); \
|
|
||||||
} \
|
|
||||||
static inline const CPP_CLASS* const& unwrap(const C_TYPEDEF* const& o) { \
|
|
||||||
return reinterpret_cast<const CPP_CLASS* const&>(o); \
|
|
||||||
} \
|
|
||||||
static inline C_TYPEDEF* const& wrap(CPP_CLASS* const& o) { \
|
|
||||||
return reinterpret_cast<C_TYPEDEF* const&>(o); \
|
|
||||||
} \
|
|
||||||
static inline const C_TYPEDEF* const& wrap(const CPP_CLASS* const& o) { \
|
|
||||||
return reinterpret_cast<const C_TYPEDEF* const&>(o); \
|
|
||||||
}
|
|
||||||
|
|
||||||
MAKE_WRAP_UNWRAP(TF_ExecutionContext, ExecutionContext)
|
|
||||||
MAKE_WRAP_UNWRAP(TF_AbstractFunction, AbstractFunction)
|
|
||||||
MAKE_WRAP_UNWRAP(TF_AbstractTensor, AbstractTensor)
|
|
||||||
MAKE_WRAP_UNWRAP(TF_AbstractOp, AbstractOp)
|
|
||||||
MAKE_WRAP_UNWRAP(TF_OutputList, OutputList)
|
|
||||||
|
|
||||||
template <typename T, typename S>
|
|
||||||
T* dynamic_cast_helper(S source) {
|
|
||||||
if (source->getKind() != T::kKind) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return tensorflow::down_cast<T*>(source);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace internal
|
|
||||||
} // namespace tensorflow
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_PRIVATE_H_
|
|
@ -15,17 +15,14 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||||
|
|
||||||
#include <string.h>
|
#include <memory>
|
||||||
|
|
||||||
#include "tensorflow/c/eager/c_api.h"
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||||
#include "tensorflow/cc/profiler/profiler.h"
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
#include "tensorflow/core/lib/monitoring/collection_registry.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/c/tf_tensor.h"
|
||||||
#include "tensorflow/core/platform/protobuf.h"
|
|
||||||
#include "tensorflow/core/platform/str_util.h"
|
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/platform/test_benchmark.h"
|
|
||||||
|
|
||||||
using tensorflow::string;
|
using tensorflow::string;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user