Improve documentation for the unified C API and perform some minor cleanups

- Document most of the APIs,
- move the C++ implementation in the tensorflow::internal
namespace,
- separate more cleanly the C API entry points,
- Drop the TF_ prefix from the C++ classes where it was missed during
  refactoring.
- Use IWYU to cleanup the includes
- Rename c_api_unified_experimental_private.h into
  c_api_unified_experimental_internal.h for align with the convention in the
  directory.

PiperOrigin-RevId: 308699592
Change-Id: Id7198f73a63c400ce94bdeefeef010f441622a88
This commit is contained in:
Mehdi Amini 2020-04-27 14:32:29 -07:00 committed by TensorFlower Gardener
parent c6eaf562e7
commit ea74b01d80
8 changed files with 357 additions and 247 deletions

View File

@ -369,7 +369,7 @@ tf_cuda_library(
"c_api_unified_experimental.cc",
"c_api_unified_experimental_eager.cc",
"c_api_unified_experimental_graph.cc",
"c_api_unified_experimental_private.h",
"c_api_unified_experimental_internal.h",
],
hdrs = [
"c_api_experimental.h",
@ -466,6 +466,7 @@ tf_cuda_cc_test(
":c_api",
":c_api_experimental",
":c_api_test_util",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",
"//tensorflow/core:lib",

View File

@ -13,24 +13,28 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental_private.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include <vector>
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::string;
using tensorflow::internal::OutputList;
using tensorflow::internal::unwrap;
using tensorflow::internal::wrap;
// =============================================================================
// Public C API entry points
//
// These are only the generic entry points for the C API. This file does not
// have any visibility into the graph/eager implementation and is only providing
// C bindings to the abstract classes defined in the
// c_api_unified_experimental_internal.h header.
//
// =============================================================================
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }

View File

@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#ifdef __cplusplus
@ -34,23 +35,34 @@ extern "C" {
// E.g. it could know whether we're in eager mode or in graph mode, keeps track
// of gradient tapes, etc.
typedef struct TF_ExecutionContext TF_ExecutionContext;
// A TF_AbstractTensor is an input to an operation. E.g. it could be a union
// type of eager and graph tensors.
// type of eager and graph tensors. It is also the result of executing an
// operation.
typedef struct TF_AbstractTensor TF_AbstractTensor;
// A TF_AbstractOp is the metadata we need to execute an operation. E.g. this
// could contain the op type and other attributes.
typedef struct TF_AbstractOp TF_AbstractOp;
// Stores a function representation that can be used for execution or for
// setting functional attributes of other composite ops e.g. control flow.
typedef struct TF_AbstractFunction TF_AbstractFunction;
// Creates a context for tracing the execution of operations into a function.
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s);
// Creates a context for eager execution of operations.
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
TF_Status* s);
void TF_DeleteExecutionContext(TF_ExecutionContext*);
// Create an operation suitable to use with the provided context. The operation
// requires its type (e.g. "AddV2") to be set independently.
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
void TF_DeleteAbstractOp(TF_AbstractOp*);
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
// TODO(srbs): Add APIs for specifying attrs etc.
// `op_type` must outlive `op`.
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
@ -62,9 +74,16 @@ void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
TF_DataType value, TF_Status* s);
// TF_OutputList just lets us not specify the number of outputs of an operation
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
// TF_OutputList holds the list of TF_AbstractTensor that results from executing
// an operation.
// It just lets us not specify the number of outputs of an operation
// beforehand. This forces a memory allocation in the runtime, which is bad, but
// it allows for generic code.
// TODO(aminim): the description above isn't clear with respect to
// TF_OutputListNumOutputs and the current eager implementation which requires
// the number of outputs to be set by the client.
typedef struct TF_OutputList TF_OutputList;
TF_OutputList* TF_NewOutputList();
void TF_DeleteOutputList(TF_OutputList* o);
@ -72,27 +91,32 @@ void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
int TF_OutputListNumOutputs(TF_OutputList* o);
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
// Stores a function representation that can be used for execution or for
// setting functional attributes of other composite ops e.g. control flow.
typedef struct TF_AbstractFunction TF_AbstractFunction;
TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status);
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext*,
TF_AbstractFunction*, TF_Status*);
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
// capture some inputs and then add a node in the graph, and after
// execution/node creation it'll go and record things that happened in any tape
// which happens to be active.
// capture some inputs and then add a node in the graph. The output tensors are
// returned through the provided TF_OutputList.
// Any active tape will observe the effects of this execution.
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s);
// Creates a new TF_AbstractFunction from the current tracing states in the
// context. The returned TF_GraphToFunction must be deleted by the client.
// TODO(aminim): clarify the contract on the state of the context after this
// call.
TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status);
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
// Register the function with the given context. This is particularly useful for
// making a function available to an eager context.
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext*,
TF_AbstractFunction*, TF_Status*);
// -----------------------------------------------------------------------------
// APIs specific to Eager and graph modes
// APIs specific to Eager modes
// -----------------------------------------------------------------------------
// Temporary APIs till we figure out how to create scalar valued Eager

View File

@ -13,32 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include <vector>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental_private.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::string;
using tensorflow::internal::AbstractFunction;
using tensorflow::internal::AbstractOp;
using tensorflow::internal::AbstractTensor;
using tensorflow::internal::dynamic_cast_helper;
using tensorflow::internal::ExecutionContext;
using tensorflow::internal::OutputList;
using tensorflow::internal::unwrap;
using tensorflow::internal::wrap;
class TF_EagerContext;
namespace tensorflow {
namespace internal {
// Simple wrapper over a TFE_TensorHandle
struct EagerTensor : public AbstractTensor {
TFE_TensorHandle* t = nullptr;
EagerTensor() : AbstractTensor(kKind) {}
@ -47,9 +39,10 @@ struct EagerTensor : public AbstractTensor {
static constexpr AbstractTensorKind kKind = kEagerTensor;
};
class TF_EagerOp : public AbstractOp {
// Simple wrapper over a TFE_Op
class EagerOp : public AbstractOp {
public:
explicit TF_EagerOp(TFE_Context* ctx) : AbstractOp(kKind), ctx_(ctx) {}
explicit EagerOp(TFE_Context* ctx) : AbstractOp(kKind), ctx_(ctx) {}
void SetOpType(const char* const op_type, TF_Status* s) override {
op_ = TFE_NewOp(ctx_, op_type, s);
}
@ -66,18 +59,19 @@ class TF_EagerOp : public AbstractOp {
TFE_OpSetAttrType(op_, attr_name, value);
}
~TF_EagerOp() override { TFE_DeleteOp(op_); }
~EagerOp() override { TFE_DeleteOp(op_); }
static constexpr AbstractOpKind kKind = kEagerOp;
private:
friend class TF_EagerContext; // For access to op_.
friend class EagerContext; // For access to op_.
TFE_Op* op_ = nullptr;
TFE_Context* ctx_;
};
class TF_EagerContext : public ExecutionContext {
// Wraps a TFE_Context and dispatch EagerOp with EagerTensor inputs.
class EagerContext : public ExecutionContext {
public:
TF_EagerContext() : ExecutionContext(kKind) {}
EagerContext() : ExecutionContext(kKind) {}
void Build(TFE_ContextOptions* options, TF_Status* status) {
eager_ctx_ = TFE_NewContext(options, status);
@ -85,13 +79,13 @@ class TF_EagerContext : public ExecutionContext {
AbstractOp* CreateOperation() override {
// TODO(srbs): Should the lifetime of this op be tied to the context.
return new TF_EagerOp(eager_ctx_);
return new EagerOp(eager_ctx_);
}
void ExecuteOperation(AbstractOp* op, int num_inputs,
AbstractTensor* const* inputs, OutputList* o,
TF_Status* s) override {
auto* eager_op = dynamic_cast_helper<TF_EagerOp>(op);
auto* eager_op = dyncast<EagerOp>(op);
if (eager_op == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Unable to cast AbstractOp to TF_EagerOp.");
@ -100,7 +94,7 @@ class TF_EagerContext : public ExecutionContext {
auto* tfe_op = eager_op->op_;
if (TF_GetCode(s) != TF_OK) return;
for (int i = 0; i < num_inputs; ++i) {
auto* eager_tensor = dynamic_cast_helper<const EagerTensor>(inputs[i]);
auto* eager_tensor = dyncast<const EagerTensor>(inputs[i]);
if (!eager_tensor) {
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
return;
@ -137,34 +131,45 @@ class TF_EagerContext : public ExecutionContext {
TFE_ContextAddFunction(eager_ctx_, func, s);
}
~TF_EagerContext() override { TFE_DeleteContext(eager_ctx_); }
~EagerContext() override { TFE_DeleteContext(eager_ctx_); }
static constexpr ExecutionContextKind kKind = kEagerContext;
private:
friend TFE_Context* TF_ExecutionContextGetTFEContext(
friend TFE_Context* ::TF_ExecutionContextGetTFEContext(
TF_ExecutionContext* ctx);
TFE_Context* eager_ctx_;
};
} // namespace internal
} // namespace tensorflow
// =============================================================================
// Public C API entry points
// These are only the entry points specific to the Eager API.
// =============================================================================
using tensorflow::internal::dyncast;
using tensorflow::internal::unwrap;
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions* options,
TF_Status* s) {
auto* ctx = new TF_EagerContext();
auto* ctx = new tensorflow::internal::EagerContext();
ctx->Build(options, s);
return wrap(ctx);
}
TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
TF_Status* s) {
return wrap(new EagerTensor(t));
return wrap(new tensorflow::internal::EagerTensor(t));
}
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s) {
auto* eager_tensor = dynamic_cast_helper<EagerTensor>(unwrap(at));
auto* eager_tensor = dyncast<tensorflow::internal::EagerTensor>(unwrap(at));
if (!eager_tensor) {
string msg = absl::StrCat("Not an eager tensor handle.",
reinterpret_cast<uintptr_t>(at));
string msg = tensorflow::strings::StrCat("Not an eager tensor handle.",
reinterpret_cast<uintptr_t>(at));
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
@ -172,5 +177,7 @@ TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
}
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) {
return dynamic_cast_helper<TF_EagerContext>(unwrap(ctx))->eager_ctx_;
auto* eager_ctx = dyncast<tensorflow::internal::EagerContext>(unwrap(ctx));
if (!eager_ctx) return nullptr;
return eager_ctx->eager_ctx_;
}

View File

@ -13,49 +13,45 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/types/variant.h"
#include <memory>
#include <vector>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental_private.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::string;
using tensorflow::internal::AbstractFunction;
using tensorflow::internal::AbstractOp;
using tensorflow::internal::AbstractTensor;
using tensorflow::internal::dynamic_cast_helper;
using tensorflow::internal::ExecutionContext;
using tensorflow::internal::OutputList;
using tensorflow::internal::unwrap;
using tensorflow::internal::wrap;
class TF_GraphContext;
namespace tensorflow {
namespace internal {
class GraphContext;
// GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
// into the list of outputs for the operation.
struct GraphTensor : public AbstractTensor {
TF_Output output{};
TF_GraphContext* ctx = nullptr;
GraphContext* ctx = nullptr;
GraphTensor() : AbstractTensor(kKind) {}
GraphTensor(TF_Output output, TF_GraphContext* ctx)
GraphTensor(TF_Output output, GraphContext* ctx)
: AbstractTensor(kKind), output(output), ctx(ctx) {}
static constexpr AbstractTensorKind kKind = kGraphTensor;
};
class TF_GraphOp : public AbstractOp {
// GraphOp wraps and populate a TF_OperationDescription.
class GraphOp : public AbstractOp {
public:
explicit TF_GraphOp(TF_Graph* g) : AbstractOp(kKind), g_(g) {}
explicit GraphOp(TF_Graph* g) : AbstractOp(kKind), g_(g) {}
void SetOpType(const char* const op_type, TF_Status* s) override {
if (op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
absl::StrCat("SetOpType called on already built op.").c_str());
strings::StrCat("SetOpType called on already built op.").c_str());
return;
}
if (op_name_ != nullptr) {
@ -69,7 +65,7 @@ class TF_GraphOp : public AbstractOp {
if (op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
absl::StrCat("SetOpName called on already built op.").c_str());
strings::StrCat("SetOpName called on already built op.").c_str());
return;
}
if (op_type_ != nullptr) {
@ -89,12 +85,12 @@ class TF_GraphOp : public AbstractOp {
}
TF_SetAttrType(op_.get(), attr_name, value);
}
~TF_GraphOp() override {}
~GraphOp() override {}
static constexpr AbstractOpKind kKind = kGraphOp;
private:
friend class TF_GraphContext; // For access to op_.
friend class GraphContext; // For access to op_.
TF_Graph* g_;
std::unique_ptr<TF_OperationDescription> op_;
// Hold `op_type` and `op_name` till both are available since we need both
@ -103,6 +99,7 @@ class TF_GraphOp : public AbstractOp {
const char* op_name_ = nullptr;
};
// GraphFunction is a thin wrapper over a TF_Function.
struct GraphFunction : public AbstractFunction {
TF_Function* func = nullptr;
GraphFunction() : AbstractFunction(kKind) {}
@ -117,20 +114,22 @@ struct GraphFunction : public AbstractFunction {
static constexpr AbstractFunctionKind kKind = kGraphFunc;
};
class TF_GraphContext : public ExecutionContext {
// GraphContext wraps a TF_Graph and manages the "execution" of operation, i.e.
// adding them to the graph.
class GraphContext : public ExecutionContext {
public:
TF_GraphContext()
GraphContext()
: ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
AbstractOp* CreateOperation() override {
// TODO(srbs): Should the lifetime of this op be tied to the context.
return new TF_GraphOp(graph_.get());
return new GraphOp(graph_.get());
}
void ExecuteOperation(AbstractOp* op, int num_inputs,
AbstractTensor* const* inputs, OutputList* o,
TF_Status* s) override {
auto* graph_op = dynamic_cast_helper<TF_GraphOp>(op);
auto* graph_op = dyncast<GraphOp>(op);
if (graph_op == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Unable to cast AbstractOp to TF_GraphOp.");
@ -138,7 +137,7 @@ class TF_GraphContext : public ExecutionContext {
}
auto* tf_opdesc = graph_op->op_.release();
for (int i = 0; i < num_inputs; ++i) {
auto* graph_tensor = dynamic_cast_helper<GraphTensor>(inputs[i]);
auto* graph_tensor = dyncast<GraphTensor>(inputs[i]);
if (!graph_tensor) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Capturing eager tensors is not supported yet.");
@ -190,7 +189,7 @@ class TF_GraphContext : public ExecutionContext {
"Registering graph functions has not been implemented yet.");
}
~TF_GraphContext() override {}
~GraphContext() override {}
static constexpr ExecutionContextKind kKind = kGraphContext;
@ -198,26 +197,23 @@ class TF_GraphContext : public ExecutionContext {
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
};
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s) {
return wrap(new TF_GraphContext());
}
TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status) {
auto* graph_ctx = dynamic_cast_helper<const TF_GraphContext>(unwrap(fn_body));
// Helper that converts the graph currently held in the context into a function.
static AbstractFunction* ExecutionContextToFunction(
const ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const AbstractTensor* inputs, int num_outputs,
const AbstractTensor* outputs, TF_Status* status) {
auto* graph_ctx = dyncast<const GraphContext>(fn_body);
if (graph_ctx == nullptr) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"fn_body is not a TF_GraphContext.");
return nullptr;
}
auto* graph_inputs = dynamic_cast_helper<const GraphTensor>(unwrap(inputs));
auto* graph_inputs = dyncast<const GraphTensor>(inputs);
if (!graph_inputs) {
TF_SetStatus(status, TF_INVALID_ARGUMENT, "inputs aren't GraphTensors.");
return nullptr;
}
auto* graph_outputs = dynamic_cast_helper<const GraphTensor>(unwrap(outputs));
auto* graph_outputs = dyncast<const GraphTensor>(outputs);
if (!graph_outputs) {
TF_SetStatus(status, TF_INVALID_ARGUMENT, "outputs aren't GraphTensors.");
return nullptr;
@ -225,5 +221,28 @@ TF_AbstractFunction* TF_ExecutionContextToFunction(
GraphFunction* func = new GraphFunction;
func->func = graph_ctx->ToFunction(fn_name, num_inputs, graph_inputs,
num_outputs, graph_outputs, status);
return wrap(func);
return func;
}
} // namespace internal
} // namespace tensorflow
// =============================================================================
// Public C API entry points
// These are only the entry points specific to the Graph API.
// =============================================================================
using tensorflow::internal::unwrap;
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s) {
return wrap(new tensorflow::internal::GraphContext());
}
TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status) {
return wrap(ExecutionContextToFunction(unwrap(fn_body), fn_name, num_inputs,
unwrap(inputs), num_outputs,
unwrap(outputs), status));
}

View File

@ -0,0 +1,184 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
#include <vector>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/platform/casts.h"
namespace tensorflow {
namespace internal {
// =============================================================================
// Implementation detail for the unified execution APIs for Eager and tracing
// backends (graph/MLIR).
//
// This defines a set of abstract classes that are intended to provide the
// functionality of the opaque C types exposed in the public APIs defined in the
// `c_api_unified_experimental.h` header.
// =============================================================================
// We can't depend on C++ rtti, but we still want to be able to have a safe
// dynamic_cast to provide diagnostics to the user when the API is misused.
// Instead we model RTTI by listing all the possible subclasses for each
// abstract base. Each subclass initializes the base class with the right
// `kind`, which allows an equivalent to `std::dynamic_cast` provided by this
// utility.
template <typename T, typename S>
T* dyncast(S source) {
if (source->getKind() != T::kKind) {
return nullptr;
}
return tensorflow::down_cast<T*>(source);
}
// Represents either an EagerTensor or a GraphTensor.
// This base class does not expose any public methods other than to distinguish
// which subclass it actually is. The user is responsible to use the right
// type of AbstractTensor in their context (do not pass an EagerTensor to a
// GraphContext and vice-versa).
class AbstractTensor {
protected:
enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor };
explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
public:
// Returns which subclass is this instance of.
AbstractTensorKind getKind() const { return kind_; }
virtual ~AbstractTensor() = default;
private:
const AbstractTensorKind kind_;
};
// Represents the results of the execution of an operation.
struct OutputList {
std::vector<AbstractTensor*> outputs;
int expected_num_outputs = -1;
};
// Holds the result of tracing a function.
class AbstractFunction {
protected:
enum AbstractFunctionKind { kGraphFunc };
explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {}
public:
// Returns which subclass is this instance of.
AbstractFunctionKind getKind() const { return kind_; }
virtual ~AbstractFunction() = default;
// Temporary API till we figure the right abstraction for AbstractFunction.
// At the moment both Eager and Graph needs access to a "TF_Function" object.
virtual TF_Function* GetTfFunction(TF_Status* s) = 0;
private:
const AbstractFunctionKind kind_;
};
// An abstract operation describes an operation by its type, name, and
// attributes. It can be "executed" by the context with some input tensors.
// It is allowed to reusing the same abstract operation for multiple execution
// on a given context, with the same or different input tensors.
class AbstractOp {
protected:
enum AbstractOpKind { kGraphOp, kEagerOp };
explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
public:
// Returns which subclass is this instance of.
AbstractOpKind getKind() const { return kind_; }
virtual ~AbstractOp() = default;
// Sets the type of the operation (for example `AddV2`).
virtual void SetOpType(const char* op_type, TF_Status* s) = 0;
// Sets the name of the operation: this is an optional identifier that is
// not intended to carry semantics and preserved/propagated without
// guarantees.
virtual void SetOpName(const char* op_name, TF_Status* s) = 0;
// Add a `TypeAttribute` on the operation.
virtual void SetAttrType(const char* attr_name, TF_DataType value,
TF_Status* s) = 0;
private:
const AbstractOpKind kind_;
};
// This holds the context for the execution: dispatching operations either to an
// eager implementation or to a graph implementation.
struct ExecutionContext {
protected:
enum ExecutionContextKind { kGraphContext, kEagerContext };
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
public:
// Returns which subclass is this instance of.
ExecutionContextKind getKind() const { return k; }
virtual ~ExecutionContext() = default;
// Executes the operation on the provided inputs and populate the OutputList
// with the results. The input tensors must match the current context.
// The effect of "executing" an operation depends on the context: in an Eager
// context it will dispatch it to the runtime for execution, while in a
// tracing context it will add the operation to the current function.
virtual void ExecuteOperation(AbstractOp* op, int num_inputs,
AbstractTensor* const* inputs, OutputList* o,
TF_Status* s) = 0;
// Creates an empty AbstractOperation suitable to use with this context.
virtual AbstractOp* CreateOperation() = 0;
// Registers a functions with this context, after this the function is
// available to be called/referenced by its name in this context.
virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
private:
const ExecutionContextKind k;
};
// Create utilities to wrap/unwrap: this convert from the C opaque types to the
// C++ implementation, and back.
#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \
static inline CPP_CLASS* const& unwrap(C_TYPEDEF* const& o) { \
return reinterpret_cast<CPP_CLASS* const&>(o); \
} \
static inline const CPP_CLASS* const& unwrap(const C_TYPEDEF* const& o) { \
return reinterpret_cast<const CPP_CLASS* const&>(o); \
} \
static inline C_TYPEDEF* const& wrap(CPP_CLASS* const& o) { \
return reinterpret_cast<C_TYPEDEF* const&>(o); \
} \
static inline const C_TYPEDEF* const& wrap(const CPP_CLASS* const& o) { \
return reinterpret_cast<const C_TYPEDEF* const&>(o); \
}
MAKE_WRAP_UNWRAP(TF_ExecutionContext, ExecutionContext)
MAKE_WRAP_UNWRAP(TF_AbstractFunction, AbstractFunction)
MAKE_WRAP_UNWRAP(TF_AbstractTensor, AbstractTensor)
MAKE_WRAP_UNWRAP(TF_AbstractOp, AbstractOp)
MAKE_WRAP_UNWRAP(TF_OutputList, OutputList)
} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_

View File

@ -1,126 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_PRIVATE_H_
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_PRIVATE_H_
#include <vector>
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/core/platform/casts.h"
namespace tensorflow {
namespace internal {
// =============================================================================
// Unified Execution APIs for Eager and tracing backends.
// =============================================================================
struct AbstractTensor {
enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor };
explicit AbstractTensor(AbstractTensorKind kind) : k(kind) {}
AbstractTensorKind getKind() const { return k; }
virtual ~AbstractTensor() = default;
private:
const AbstractTensorKind k;
};
struct OutputList {
std::vector<AbstractTensor*> outputs;
int expected_num_outputs = -1;
};
struct AbstractFunction {
enum AbstractFunctionKind { kGraphFunc };
explicit AbstractFunction(AbstractFunctionKind kind) : k(kind) {}
AbstractFunctionKind getKind() const { return k; }
virtual ~AbstractFunction() = default;
// Temporary API till we figure the right abstraction for AbstractFunction
virtual TF_Function* GetTfFunction(TF_Status* s) = 0;
private:
const AbstractFunctionKind k;
};
struct AbstractOp {
// Needed to implement our own version of RTTI since dynamic_cast is not
// supported in mobile builds.
enum AbstractOpKind { kGraphOp, kEagerOp };
explicit AbstractOp(AbstractOpKind kind) : k(kind) {}
AbstractOpKind getKind() const { return k; }
virtual void SetOpType(const char* const op_type, TF_Status* s) = 0;
virtual void SetOpName(const char* const op_name, TF_Status* s) = 0;
virtual void SetAttrType(const char* const attr_name, TF_DataType value,
TF_Status* s) = 0;
virtual ~AbstractOp() {}
private:
const AbstractOpKind k;
};
struct ExecutionContext {
// Needed to implement our own version of RTTI since dynamic_cast is not
// supported in mobile builds.
enum ExecutionContextKind { kGraphContext, kEagerContext };
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
ExecutionContextKind getKind() const { return k; }
virtual void ExecuteOperation(AbstractOp* op, int num_inputs,
AbstractTensor* const* inputs, OutputList* o,
TF_Status* s) = 0;
virtual AbstractOp* CreateOperation() = 0;
virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
virtual ~ExecutionContext() = default;
private:
const ExecutionContextKind k;
};
// Create utilities to wrap/unwrap: this convert from the C opaque types to the
// C++ implementation, and back.
#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \
static inline CPP_CLASS* const& unwrap(C_TYPEDEF* const& o) { \
return reinterpret_cast<CPP_CLASS* const&>(o); \
} \
static inline const CPP_CLASS* const& unwrap(const C_TYPEDEF* const& o) { \
return reinterpret_cast<const CPP_CLASS* const&>(o); \
} \
static inline C_TYPEDEF* const& wrap(CPP_CLASS* const& o) { \
return reinterpret_cast<C_TYPEDEF* const&>(o); \
} \
static inline const C_TYPEDEF* const& wrap(const CPP_CLASS* const& o) { \
return reinterpret_cast<const C_TYPEDEF* const&>(o); \
}
MAKE_WRAP_UNWRAP(TF_ExecutionContext, ExecutionContext)
MAKE_WRAP_UNWRAP(TF_AbstractFunction, AbstractFunction)
MAKE_WRAP_UNWRAP(TF_AbstractTensor, AbstractTensor)
MAKE_WRAP_UNWRAP(TF_AbstractOp, AbstractOp)
MAKE_WRAP_UNWRAP(TF_OutputList, OutputList)
template <typename T, typename S>
T* dynamic_cast_helper(S source) {
if (source->getKind() != T::kKind) {
return nullptr;
}
return tensorflow::down_cast<T*>(source);
}
} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_PRIVATE_H_

View File

@ -15,17 +15,14 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include <string.h>
#include <memory>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/cc/profiler/profiler.h"
#include "tensorflow/core/lib/monitoring/collection_registry.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
using tensorflow::string;