Refactor the unified C API: split the eager/graph implementation in different files
This is the last stage of the refactoring, after this: - the generic implementation remains in `c_api_unified_experimental.cc`: this contains the implementation for all the C API operating on abstract types. For each type, the C binding is unwrapping (casting to C++ abstract class) and dispatching to virtual methods on the C++ abstract classes. - The eager implementation of the abstract classes is in `c_api_unified_experimental_eager.cc`: this implements the AbstractFunction, AbstractOp, AbstractTensor, and ExecutionContext classes for eager execution. - The graph implementation of these same abstract classes is in `c_api_unified_experimental_graph.cc`. Some C APIs specific to eager or graph are implemented in the respective files. PiperOrigin-RevId: 308423268 Change-Id: I354f948382c078d19d5988775226b5ce4ca2a9d0
This commit is contained in:
parent
81d7cee332
commit
be3e3961f8
@ -367,6 +367,8 @@ tf_cuda_library(
|
|||||||
srcs = [
|
srcs = [
|
||||||
"c_api_experimental.cc",
|
"c_api_experimental.cc",
|
||||||
"c_api_unified_experimental.cc",
|
"c_api_unified_experimental.cc",
|
||||||
|
"c_api_unified_experimental_eager.cc",
|
||||||
|
"c_api_unified_experimental_graph.cc",
|
||||||
"c_api_unified_experimental_private.h",
|
"c_api_unified_experimental_private.h",
|
||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
|
@ -28,305 +28,20 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/strcat.h"
|
#include "tensorflow/core/platform/strcat.h"
|
||||||
|
|
||||||
using tensorflow::string;
|
using tensorflow::string;
|
||||||
using tensorflow::internal::AbstractFunction;
|
|
||||||
using tensorflow::internal::AbstractOp;
|
|
||||||
using tensorflow::internal::AbstractTensor;
|
|
||||||
using tensorflow::internal::dynamic_cast_helper;
|
|
||||||
using tensorflow::internal::ExecutionContext;
|
|
||||||
using tensorflow::internal::OutputList;
|
using tensorflow::internal::OutputList;
|
||||||
using tensorflow::internal::unwrap;
|
using tensorflow::internal::unwrap;
|
||||||
using tensorflow::internal::wrap;
|
using tensorflow::internal::wrap;
|
||||||
|
|
||||||
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }
|
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }
|
||||||
|
|
||||||
class TF_GraphContext;
|
|
||||||
class TF_EagerContext;
|
|
||||||
|
|
||||||
struct EagerTensor : public AbstractTensor {
|
|
||||||
TFE_TensorHandle* t = nullptr;
|
|
||||||
EagerTensor() : AbstractTensor(kKind) {}
|
|
||||||
explicit EagerTensor(TFE_TensorHandle* t) : AbstractTensor(kKind), t(t) {}
|
|
||||||
~EagerTensor() override { TFE_DeleteTensorHandle(t); }
|
|
||||||
static constexpr AbstractTensorKind kKind = kEagerTensor;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct GraphTensor : public AbstractTensor {
|
|
||||||
TF_Output output{};
|
|
||||||
TF_GraphContext* ctx = nullptr;
|
|
||||||
GraphTensor() : AbstractTensor(kKind) {}
|
|
||||||
GraphTensor(TF_Output output, TF_GraphContext* ctx)
|
|
||||||
: AbstractTensor(kKind), output(output), ctx(ctx) {}
|
|
||||||
static constexpr AbstractTensorKind kKind = kGraphTensor;
|
|
||||||
};
|
|
||||||
|
|
||||||
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
|
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
|
||||||
return wrap(unwrap(c)->CreateOperation());
|
return wrap(unwrap(c)->CreateOperation());
|
||||||
}
|
}
|
||||||
|
|
||||||
void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete unwrap(op); }
|
void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete unwrap(op); }
|
||||||
|
|
||||||
class TF_GraphOp : public AbstractOp {
|
|
||||||
public:
|
|
||||||
explicit TF_GraphOp(TF_Graph* g) : AbstractOp(kKind), g_(g) {}
|
|
||||||
void SetOpType(const char* const op_type, TF_Status* s) override {
|
|
||||||
if (op_) {
|
|
||||||
TF_SetStatus(
|
|
||||||
s, TF_FAILED_PRECONDITION,
|
|
||||||
absl::StrCat("SetOpType called on already built op.").c_str());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (op_name_ != nullptr) {
|
|
||||||
op_.reset(TF_NewOperation(g_, op_type, op_name_));
|
|
||||||
op_name_ = nullptr;
|
|
||||||
} else {
|
|
||||||
op_type_ = op_type;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
void SetOpName(const char* const op_name, TF_Status* s) override {
|
|
||||||
if (op_) {
|
|
||||||
TF_SetStatus(
|
|
||||||
s, TF_FAILED_PRECONDITION,
|
|
||||||
absl::StrCat("SetOpName called on already built op.").c_str());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (op_type_ != nullptr) {
|
|
||||||
op_.reset(TF_NewOperation(g_, op_type_, op_name));
|
|
||||||
op_type_ = nullptr;
|
|
||||||
} else {
|
|
||||||
op_name_ = op_name;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
void SetAttrType(const char* const attr_name, TF_DataType value,
|
|
||||||
TF_Status* s) override {
|
|
||||||
if (!op_) {
|
|
||||||
TF_SetStatus(
|
|
||||||
s, TF_FAILED_PRECONDITION,
|
|
||||||
"op_type and op_name must be specified before specifying attrs.");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
TF_SetAttrType(op_.get(), attr_name, value);
|
|
||||||
}
|
|
||||||
~TF_GraphOp() override {}
|
|
||||||
|
|
||||||
static constexpr AbstractOpKind kKind = kGraphOp;
|
|
||||||
|
|
||||||
private:
|
|
||||||
friend class TF_GraphContext; // For access to op_.
|
|
||||||
TF_Graph* g_;
|
|
||||||
std::unique_ptr<TF_OperationDescription> op_;
|
|
||||||
// Hold `op_type` and `op_name` till both are available since we need both
|
|
||||||
// to build a graph operation.
|
|
||||||
const char* op_type_ = nullptr;
|
|
||||||
const char* op_name_ = nullptr;
|
|
||||||
};
|
|
||||||
|
|
||||||
class TF_EagerOp : public AbstractOp {
|
|
||||||
public:
|
|
||||||
explicit TF_EagerOp(TFE_Context* ctx) : AbstractOp(kKind), ctx_(ctx) {}
|
|
||||||
void SetOpType(const char* const op_type, TF_Status* s) override {
|
|
||||||
op_ = TFE_NewOp(ctx_, op_type, s);
|
|
||||||
}
|
|
||||||
void SetOpName(const char* const op_name, TF_Status* s) override {
|
|
||||||
// Name is ignored in eager mode.
|
|
||||||
}
|
|
||||||
void SetAttrType(const char* const attr_name, TF_DataType value,
|
|
||||||
TF_Status* s) override {
|
|
||||||
if (op_ == nullptr) {
|
|
||||||
TF_SetStatus(s, TF_FAILED_PRECONDITION,
|
|
||||||
"op_type must be specified before specifying attrs.");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
TFE_OpSetAttrType(op_, attr_name, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
~TF_EagerOp() override { TFE_DeleteOp(op_); }
|
|
||||||
static constexpr AbstractOpKind kKind = kEagerOp;
|
|
||||||
|
|
||||||
private:
|
|
||||||
friend class TF_EagerContext; // For access to op_.
|
|
||||||
TFE_Op* op_ = nullptr;
|
|
||||||
TFE_Context* ctx_;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct GraphFunction : public AbstractFunction {
|
|
||||||
TF_Function* func = nullptr;
|
|
||||||
GraphFunction() : AbstractFunction(kKind) {}
|
|
||||||
explicit GraphFunction(TF_Function* func)
|
|
||||||
: AbstractFunction(kKind), func(func) {}
|
|
||||||
~GraphFunction() override {
|
|
||||||
if (func) TF_DeleteFunction(func);
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_Function* GetTfFunction(TF_Status* s) override { return func; }
|
|
||||||
|
|
||||||
static constexpr AbstractFunctionKind kKind = kGraphFunc;
|
|
||||||
};
|
|
||||||
|
|
||||||
class TF_EagerContext : public ExecutionContext {
|
|
||||||
public:
|
|
||||||
TF_EagerContext() : ExecutionContext(kKind) {}
|
|
||||||
|
|
||||||
void Build(TFE_ContextOptions* options, TF_Status* status) {
|
|
||||||
eager_ctx_ = TFE_NewContext(options, status);
|
|
||||||
}
|
|
||||||
|
|
||||||
AbstractOp* CreateOperation() override {
|
|
||||||
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
|
||||||
return new TF_EagerOp(eager_ctx_);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ExecuteOperation(AbstractOp* op, int num_inputs,
|
|
||||||
AbstractTensor* const* inputs, OutputList* o,
|
|
||||||
TF_Status* s) override {
|
|
||||||
auto* eager_op = dynamic_cast_helper<TF_EagerOp>(op);
|
|
||||||
if (eager_op == nullptr) {
|
|
||||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
|
||||||
"Unable to cast AbstractOp to TF_EagerOp.");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto* tfe_op = eager_op->op_;
|
|
||||||
if (TF_GetCode(s) != TF_OK) return;
|
|
||||||
for (int i = 0; i < num_inputs; ++i) {
|
|
||||||
auto* eager_tensor = dynamic_cast_helper<const EagerTensor>(inputs[i]);
|
|
||||||
if (!eager_tensor) {
|
|
||||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
TFE_OpAddInput(tfe_op, eager_tensor->t, s);
|
|
||||||
if (TF_GetCode(s) != TF_OK) return;
|
|
||||||
}
|
|
||||||
if (o->expected_num_outputs == -1) {
|
|
||||||
string msg =
|
|
||||||
"The number of outputs must be provided in eager mode. Use "
|
|
||||||
"TF_OutputListSetNumOutputs.";
|
|
||||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
|
|
||||||
int num_retvals = o->expected_num_outputs;
|
|
||||||
retvals.resize(num_retvals);
|
|
||||||
TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
|
|
||||||
if (TF_GetCode(s) != TF_OK) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
o->outputs.clear();
|
|
||||||
o->outputs.reserve(num_retvals);
|
|
||||||
for (int i = 0; i < num_retvals; ++i) {
|
|
||||||
o->outputs.push_back(new EagerTensor(retvals[i]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override {
|
|
||||||
auto* func = afunc->GetTfFunction(s);
|
|
||||||
if (!func) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
TFE_ContextAddFunction(eager_ctx_, func, s);
|
|
||||||
}
|
|
||||||
|
|
||||||
~TF_EagerContext() override { TFE_DeleteContext(eager_ctx_); }
|
|
||||||
|
|
||||||
static constexpr ExecutionContextKind kKind = kEagerContext;
|
|
||||||
|
|
||||||
private:
|
|
||||||
friend TFE_Context* TF_ExecutionContextGetTFEContext(
|
|
||||||
TF_ExecutionContext* ctx);
|
|
||||||
TFE_Context* eager_ctx_;
|
|
||||||
};
|
|
||||||
|
|
||||||
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete unwrap(t); }
|
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete unwrap(t); }
|
||||||
|
|
||||||
class TF_GraphContext : public ExecutionContext {
|
|
||||||
public:
|
|
||||||
TF_GraphContext()
|
|
||||||
: ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
|
|
||||||
|
|
||||||
AbstractOp* CreateOperation() override {
|
|
||||||
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
|
||||||
return new TF_GraphOp(graph_.get());
|
|
||||||
}
|
|
||||||
|
|
||||||
void ExecuteOperation(AbstractOp* op, int num_inputs,
|
|
||||||
AbstractTensor* const* inputs, OutputList* o,
|
|
||||||
TF_Status* s) override {
|
|
||||||
auto* graph_op = dynamic_cast_helper<TF_GraphOp>(op);
|
|
||||||
if (graph_op == nullptr) {
|
|
||||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
|
||||||
"Unable to cast AbstractOp to TF_GraphOp.");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto* tf_opdesc = graph_op->op_.release();
|
|
||||||
for (int i = 0; i < num_inputs; ++i) {
|
|
||||||
auto* graph_tensor = dynamic_cast_helper<GraphTensor>(inputs[i]);
|
|
||||||
if (!graph_tensor) {
|
|
||||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
|
||||||
"Capturing eager tensors is not supported yet.");
|
|
||||||
return;
|
|
||||||
} else {
|
|
||||||
if (graph_tensor->ctx != this) {
|
|
||||||
TF_SetStatus(
|
|
||||||
s, TF_INVALID_ARGUMENT,
|
|
||||||
"Capturing tensors from other graphs is not supported yet.");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
TF_AddInput(tf_opdesc, graph_tensor->output);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto* operation = TF_FinishOperation(tf_opdesc, s);
|
|
||||||
// TF_FinishOperation deletes `tf_opdesc` so clear its reference.
|
|
||||||
graph_op->op_ = nullptr;
|
|
||||||
if (TF_GetCode(s) != TF_OK) return;
|
|
||||||
int num_outputs = TF_OperationNumOutputs(operation);
|
|
||||||
o->outputs.clear();
|
|
||||||
o->outputs.reserve(num_outputs);
|
|
||||||
for (int i = 0; i < num_outputs; ++i) {
|
|
||||||
o->outputs.push_back(new GraphTensor({operation, i}, this));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_Function* ToFunction(const char* fn_name, int num_inputs,
|
|
||||||
const GraphTensor* inputs, int num_outputs,
|
|
||||||
const GraphTensor* outputs, TF_Status* status) const {
|
|
||||||
std::vector<TF_Output> graph_inputs;
|
|
||||||
graph_inputs.resize(num_inputs);
|
|
||||||
std::vector<TF_Output> graph_outputs;
|
|
||||||
graph_outputs.resize(num_outputs);
|
|
||||||
for (int i = 0; i < num_inputs; i++) {
|
|
||||||
graph_inputs[i] = inputs[i].output;
|
|
||||||
}
|
|
||||||
for (int i = 0; i < num_outputs; i++) {
|
|
||||||
graph_outputs[i] = outputs[i].output;
|
|
||||||
}
|
|
||||||
|
|
||||||
return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr,
|
|
||||||
graph_inputs.size(), graph_inputs.data(),
|
|
||||||
graph_outputs.size(), graph_outputs.data(),
|
|
||||||
nullptr, nullptr, fn_name, status);
|
|
||||||
}
|
|
||||||
|
|
||||||
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
|
|
||||||
TF_SetStatus(s, TF_UNIMPLEMENTED,
|
|
||||||
"Registering graph functions has not been implemented yet.");
|
|
||||||
}
|
|
||||||
|
|
||||||
~TF_GraphContext() override {}
|
|
||||||
|
|
||||||
static constexpr ExecutionContextKind kKind = kGraphContext;
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
|
|
||||||
};
|
|
||||||
|
|
||||||
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s) {
|
|
||||||
return wrap(new TF_GraphContext());
|
|
||||||
}
|
|
||||||
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions* options,
|
|
||||||
TF_Status* s) {
|
|
||||||
auto* ctx = new TF_EagerContext();
|
|
||||||
ctx->Build(options, s);
|
|
||||||
return wrap(ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); }
|
TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); }
|
||||||
void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); }
|
void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); }
|
||||||
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
|
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
|
||||||
@ -362,32 +77,6 @@ void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
|
|||||||
unwrap(o), s);
|
unwrap(o), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
|
||||||
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
|
||||||
const TF_AbstractTensor* inputs, int num_outputs,
|
|
||||||
const TF_AbstractTensor* outputs, TF_Status* status) {
|
|
||||||
auto* graph_ctx = dynamic_cast_helper<const TF_GraphContext>(unwrap(fn_body));
|
|
||||||
if (graph_ctx == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
|
||||||
"fn_body is not a TF_GraphContext.");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto* graph_inputs = dynamic_cast_helper<const GraphTensor>(unwrap(inputs));
|
|
||||||
if (!graph_inputs) {
|
|
||||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "inputs aren't GraphTensors.");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto* graph_outputs = dynamic_cast_helper<const GraphTensor>(unwrap(outputs));
|
|
||||||
if (!graph_outputs) {
|
|
||||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "outputs aren't GraphTensors.");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
GraphFunction* func = new GraphFunction;
|
|
||||||
func->func = graph_ctx->ToFunction(fn_name, num_inputs, graph_inputs,
|
|
||||||
num_outputs, graph_outputs, status);
|
|
||||||
return wrap(func);
|
|
||||||
}
|
|
||||||
|
|
||||||
void TF_DeleteAbstractFunction(TF_AbstractFunction* func) {
|
void TF_DeleteAbstractFunction(TF_AbstractFunction* func) {
|
||||||
delete unwrap(func);
|
delete unwrap(func);
|
||||||
}
|
}
|
||||||
@ -397,24 +86,3 @@ void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx,
|
|||||||
TF_Status* s) {
|
TF_Status* s) {
|
||||||
unwrap(ctx)->RegisterFunction(unwrap(func), s);
|
unwrap(ctx)->RegisterFunction(unwrap(func), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
|
|
||||||
TF_Status* s) {
|
|
||||||
return wrap(new EagerTensor(t));
|
|
||||||
}
|
|
||||||
|
|
||||||
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
|
|
||||||
TF_Status* s) {
|
|
||||||
auto* eager_tensor = dynamic_cast_helper<EagerTensor>(unwrap(at));
|
|
||||||
if (!eager_tensor) {
|
|
||||||
string msg = absl::StrCat("Not an eager tensor handle.",
|
|
||||||
reinterpret_cast<uintptr_t>(at));
|
|
||||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return eager_tensor->t;
|
|
||||||
}
|
|
||||||
|
|
||||||
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) {
|
|
||||||
return dynamic_cast_helper<TF_EagerContext>(unwrap(ctx))->eager_ctx_;
|
|
||||||
}
|
|
||||||
|
176
tensorflow/c/eager/c_api_unified_experimental_eager.cc
Normal file
176
tensorflow/c/eager/c_api_unified_experimental_eager.cc
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "absl/types/variant.h"
|
||||||
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_unified_experimental_private.h"
|
||||||
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
|
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||||
|
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||||
|
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
|
#include "tensorflow/core/platform/strcat.h"
|
||||||
|
|
||||||
|
using tensorflow::string;
|
||||||
|
using tensorflow::internal::AbstractFunction;
|
||||||
|
using tensorflow::internal::AbstractOp;
|
||||||
|
using tensorflow::internal::AbstractTensor;
|
||||||
|
using tensorflow::internal::dynamic_cast_helper;
|
||||||
|
using tensorflow::internal::ExecutionContext;
|
||||||
|
using tensorflow::internal::OutputList;
|
||||||
|
using tensorflow::internal::unwrap;
|
||||||
|
using tensorflow::internal::wrap;
|
||||||
|
|
||||||
|
class TF_EagerContext;
|
||||||
|
|
||||||
|
struct EagerTensor : public AbstractTensor {
|
||||||
|
TFE_TensorHandle* t = nullptr;
|
||||||
|
EagerTensor() : AbstractTensor(kKind) {}
|
||||||
|
explicit EagerTensor(TFE_TensorHandle* t) : AbstractTensor(kKind), t(t) {}
|
||||||
|
~EagerTensor() override { TFE_DeleteTensorHandle(t); }
|
||||||
|
static constexpr AbstractTensorKind kKind = kEagerTensor;
|
||||||
|
};
|
||||||
|
|
||||||
|
class TF_EagerOp : public AbstractOp {
|
||||||
|
public:
|
||||||
|
explicit TF_EagerOp(TFE_Context* ctx) : AbstractOp(kKind), ctx_(ctx) {}
|
||||||
|
void SetOpType(const char* const op_type, TF_Status* s) override {
|
||||||
|
op_ = TFE_NewOp(ctx_, op_type, s);
|
||||||
|
}
|
||||||
|
void SetOpName(const char* const op_name, TF_Status* s) override {
|
||||||
|
// Name is ignored in eager mode.
|
||||||
|
}
|
||||||
|
void SetAttrType(const char* const attr_name, TF_DataType value,
|
||||||
|
TF_Status* s) override {
|
||||||
|
if (op_ == nullptr) {
|
||||||
|
TF_SetStatus(s, TF_FAILED_PRECONDITION,
|
||||||
|
"op_type must be specified before specifying attrs.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
TFE_OpSetAttrType(op_, attr_name, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
~TF_EagerOp() override { TFE_DeleteOp(op_); }
|
||||||
|
static constexpr AbstractOpKind kKind = kEagerOp;
|
||||||
|
|
||||||
|
private:
|
||||||
|
friend class TF_EagerContext; // For access to op_.
|
||||||
|
TFE_Op* op_ = nullptr;
|
||||||
|
TFE_Context* ctx_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class TF_EagerContext : public ExecutionContext {
|
||||||
|
public:
|
||||||
|
TF_EagerContext() : ExecutionContext(kKind) {}
|
||||||
|
|
||||||
|
void Build(TFE_ContextOptions* options, TF_Status* status) {
|
||||||
|
eager_ctx_ = TFE_NewContext(options, status);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractOp* CreateOperation() override {
|
||||||
|
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
||||||
|
return new TF_EagerOp(eager_ctx_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ExecuteOperation(AbstractOp* op, int num_inputs,
|
||||||
|
AbstractTensor* const* inputs, OutputList* o,
|
||||||
|
TF_Status* s) override {
|
||||||
|
auto* eager_op = dynamic_cast_helper<TF_EagerOp>(op);
|
||||||
|
if (eager_op == nullptr) {
|
||||||
|
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||||
|
"Unable to cast AbstractOp to TF_EagerOp.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto* tfe_op = eager_op->op_;
|
||||||
|
if (TF_GetCode(s) != TF_OK) return;
|
||||||
|
for (int i = 0; i < num_inputs; ++i) {
|
||||||
|
auto* eager_tensor = dynamic_cast_helper<const EagerTensor>(inputs[i]);
|
||||||
|
if (!eager_tensor) {
|
||||||
|
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
TFE_OpAddInput(tfe_op, eager_tensor->t, s);
|
||||||
|
if (TF_GetCode(s) != TF_OK) return;
|
||||||
|
}
|
||||||
|
if (o->expected_num_outputs == -1) {
|
||||||
|
string msg =
|
||||||
|
"The number of outputs must be provided in eager mode. Use "
|
||||||
|
"TF_OutputListSetNumOutputs.";
|
||||||
|
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
|
||||||
|
int num_retvals = o->expected_num_outputs;
|
||||||
|
retvals.resize(num_retvals);
|
||||||
|
TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
|
||||||
|
if (TF_GetCode(s) != TF_OK) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
o->outputs.clear();
|
||||||
|
o->outputs.reserve(num_retvals);
|
||||||
|
for (int i = 0; i < num_retvals; ++i) {
|
||||||
|
o->outputs.push_back(new EagerTensor(retvals[i]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override {
|
||||||
|
auto* func = afunc->GetTfFunction(s);
|
||||||
|
if (!func) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
TFE_ContextAddFunction(eager_ctx_, func, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
~TF_EagerContext() override { TFE_DeleteContext(eager_ctx_); }
|
||||||
|
|
||||||
|
static constexpr ExecutionContextKind kKind = kEagerContext;
|
||||||
|
|
||||||
|
private:
|
||||||
|
friend TFE_Context* TF_ExecutionContextGetTFEContext(
|
||||||
|
TF_ExecutionContext* ctx);
|
||||||
|
TFE_Context* eager_ctx_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions* options,
|
||||||
|
TF_Status* s) {
|
||||||
|
auto* ctx = new TF_EagerContext();
|
||||||
|
ctx->Build(options, s);
|
||||||
|
return wrap(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
|
||||||
|
TF_Status* s) {
|
||||||
|
return wrap(new EagerTensor(t));
|
||||||
|
}
|
||||||
|
|
||||||
|
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
|
||||||
|
TF_Status* s) {
|
||||||
|
auto* eager_tensor = dynamic_cast_helper<EagerTensor>(unwrap(at));
|
||||||
|
if (!eager_tensor) {
|
||||||
|
string msg = absl::StrCat("Not an eager tensor handle.",
|
||||||
|
reinterpret_cast<uintptr_t>(at));
|
||||||
|
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return eager_tensor->t;
|
||||||
|
}
|
||||||
|
|
||||||
|
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) {
|
||||||
|
return dynamic_cast_helper<TF_EagerContext>(unwrap(ctx))->eager_ctx_;
|
||||||
|
}
|
229
tensorflow/c/eager/c_api_unified_experimental_graph.cc
Normal file
229
tensorflow/c/eager/c_api_unified_experimental_graph.cc
Normal file
@ -0,0 +1,229 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "absl/types/variant.h"
|
||||||
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_unified_experimental_private.h"
|
||||||
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
|
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||||
|
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||||
|
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
|
#include "tensorflow/core/platform/strcat.h"
|
||||||
|
|
||||||
|
using tensorflow::string;
|
||||||
|
using tensorflow::internal::AbstractFunction;
|
||||||
|
using tensorflow::internal::AbstractOp;
|
||||||
|
using tensorflow::internal::AbstractTensor;
|
||||||
|
using tensorflow::internal::dynamic_cast_helper;
|
||||||
|
using tensorflow::internal::ExecutionContext;
|
||||||
|
using tensorflow::internal::OutputList;
|
||||||
|
using tensorflow::internal::unwrap;
|
||||||
|
using tensorflow::internal::wrap;
|
||||||
|
|
||||||
|
class TF_GraphContext;
|
||||||
|
|
||||||
|
struct GraphTensor : public AbstractTensor {
|
||||||
|
TF_Output output{};
|
||||||
|
TF_GraphContext* ctx = nullptr;
|
||||||
|
GraphTensor() : AbstractTensor(kKind) {}
|
||||||
|
GraphTensor(TF_Output output, TF_GraphContext* ctx)
|
||||||
|
: AbstractTensor(kKind), output(output), ctx(ctx) {}
|
||||||
|
static constexpr AbstractTensorKind kKind = kGraphTensor;
|
||||||
|
};
|
||||||
|
|
||||||
|
class TF_GraphOp : public AbstractOp {
|
||||||
|
public:
|
||||||
|
explicit TF_GraphOp(TF_Graph* g) : AbstractOp(kKind), g_(g) {}
|
||||||
|
void SetOpType(const char* const op_type, TF_Status* s) override {
|
||||||
|
if (op_) {
|
||||||
|
TF_SetStatus(
|
||||||
|
s, TF_FAILED_PRECONDITION,
|
||||||
|
absl::StrCat("SetOpType called on already built op.").c_str());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (op_name_ != nullptr) {
|
||||||
|
op_.reset(TF_NewOperation(g_, op_type, op_name_));
|
||||||
|
op_name_ = nullptr;
|
||||||
|
} else {
|
||||||
|
op_type_ = op_type;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void SetOpName(const char* const op_name, TF_Status* s) override {
|
||||||
|
if (op_) {
|
||||||
|
TF_SetStatus(
|
||||||
|
s, TF_FAILED_PRECONDITION,
|
||||||
|
absl::StrCat("SetOpName called on already built op.").c_str());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (op_type_ != nullptr) {
|
||||||
|
op_.reset(TF_NewOperation(g_, op_type_, op_name));
|
||||||
|
op_type_ = nullptr;
|
||||||
|
} else {
|
||||||
|
op_name_ = op_name;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void SetAttrType(const char* const attr_name, TF_DataType value,
|
||||||
|
TF_Status* s) override {
|
||||||
|
if (!op_) {
|
||||||
|
TF_SetStatus(
|
||||||
|
s, TF_FAILED_PRECONDITION,
|
||||||
|
"op_type and op_name must be specified before specifying attrs.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
TF_SetAttrType(op_.get(), attr_name, value);
|
||||||
|
}
|
||||||
|
~TF_GraphOp() override {}
|
||||||
|
|
||||||
|
static constexpr AbstractOpKind kKind = kGraphOp;
|
||||||
|
|
||||||
|
private:
|
||||||
|
friend class TF_GraphContext; // For access to op_.
|
||||||
|
TF_Graph* g_;
|
||||||
|
std::unique_ptr<TF_OperationDescription> op_;
|
||||||
|
// Hold `op_type` and `op_name` till both are available since we need both
|
||||||
|
// to build a graph operation.
|
||||||
|
const char* op_type_ = nullptr;
|
||||||
|
const char* op_name_ = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct GraphFunction : public AbstractFunction {
|
||||||
|
TF_Function* func = nullptr;
|
||||||
|
GraphFunction() : AbstractFunction(kKind) {}
|
||||||
|
explicit GraphFunction(TF_Function* func)
|
||||||
|
: AbstractFunction(kKind), func(func) {}
|
||||||
|
~GraphFunction() override {
|
||||||
|
if (func) TF_DeleteFunction(func);
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_Function* GetTfFunction(TF_Status* s) override { return func; }
|
||||||
|
|
||||||
|
static constexpr AbstractFunctionKind kKind = kGraphFunc;
|
||||||
|
};
|
||||||
|
|
||||||
|
class TF_GraphContext : public ExecutionContext {
|
||||||
|
public:
|
||||||
|
TF_GraphContext()
|
||||||
|
: ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
|
||||||
|
|
||||||
|
AbstractOp* CreateOperation() override {
|
||||||
|
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
||||||
|
return new TF_GraphOp(graph_.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
void ExecuteOperation(AbstractOp* op, int num_inputs,
|
||||||
|
AbstractTensor* const* inputs, OutputList* o,
|
||||||
|
TF_Status* s) override {
|
||||||
|
auto* graph_op = dynamic_cast_helper<TF_GraphOp>(op);
|
||||||
|
if (graph_op == nullptr) {
|
||||||
|
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||||
|
"Unable to cast AbstractOp to TF_GraphOp.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto* tf_opdesc = graph_op->op_.release();
|
||||||
|
for (int i = 0; i < num_inputs; ++i) {
|
||||||
|
auto* graph_tensor = dynamic_cast_helper<GraphTensor>(inputs[i]);
|
||||||
|
if (!graph_tensor) {
|
||||||
|
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||||
|
"Capturing eager tensors is not supported yet.");
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
if (graph_tensor->ctx != this) {
|
||||||
|
TF_SetStatus(
|
||||||
|
s, TF_INVALID_ARGUMENT,
|
||||||
|
"Capturing tensors from other graphs is not supported yet.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
TF_AddInput(tf_opdesc, graph_tensor->output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto* operation = TF_FinishOperation(tf_opdesc, s);
|
||||||
|
// TF_FinishOperation deletes `tf_opdesc` so clear its reference.
|
||||||
|
graph_op->op_ = nullptr;
|
||||||
|
if (TF_GetCode(s) != TF_OK) return;
|
||||||
|
int num_outputs = TF_OperationNumOutputs(operation);
|
||||||
|
o->outputs.clear();
|
||||||
|
o->outputs.reserve(num_outputs);
|
||||||
|
for (int i = 0; i < num_outputs; ++i) {
|
||||||
|
o->outputs.push_back(new GraphTensor({operation, i}, this));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_Function* ToFunction(const char* fn_name, int num_inputs,
|
||||||
|
const GraphTensor* inputs, int num_outputs,
|
||||||
|
const GraphTensor* outputs, TF_Status* status) const {
|
||||||
|
std::vector<TF_Output> graph_inputs;
|
||||||
|
graph_inputs.resize(num_inputs);
|
||||||
|
std::vector<TF_Output> graph_outputs;
|
||||||
|
graph_outputs.resize(num_outputs);
|
||||||
|
for (int i = 0; i < num_inputs; i++) {
|
||||||
|
graph_inputs[i] = inputs[i].output;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < num_outputs; i++) {
|
||||||
|
graph_outputs[i] = outputs[i].output;
|
||||||
|
}
|
||||||
|
|
||||||
|
return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr,
|
||||||
|
graph_inputs.size(), graph_inputs.data(),
|
||||||
|
graph_outputs.size(), graph_outputs.data(),
|
||||||
|
nullptr, nullptr, fn_name, status);
|
||||||
|
}
|
||||||
|
|
||||||
|
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
|
||||||
|
TF_SetStatus(s, TF_UNIMPLEMENTED,
|
||||||
|
"Registering graph functions has not been implemented yet.");
|
||||||
|
}
|
||||||
|
|
||||||
|
~TF_GraphContext() override {}
|
||||||
|
|
||||||
|
static constexpr ExecutionContextKind kKind = kGraphContext;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s) {
|
||||||
|
return wrap(new TF_GraphContext());
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
||||||
|
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||||
|
const TF_AbstractTensor* inputs, int num_outputs,
|
||||||
|
const TF_AbstractTensor* outputs, TF_Status* status) {
|
||||||
|
auto* graph_ctx = dynamic_cast_helper<const TF_GraphContext>(unwrap(fn_body));
|
||||||
|
if (graph_ctx == nullptr) {
|
||||||
|
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||||
|
"fn_body is not a TF_GraphContext.");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto* graph_inputs = dynamic_cast_helper<const GraphTensor>(unwrap(inputs));
|
||||||
|
if (!graph_inputs) {
|
||||||
|
TF_SetStatus(status, TF_INVALID_ARGUMENT, "inputs aren't GraphTensors.");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto* graph_outputs = dynamic_cast_helper<const GraphTensor>(unwrap(outputs));
|
||||||
|
if (!graph_outputs) {
|
||||||
|
TF_SetStatus(status, TF_INVALID_ARGUMENT, "outputs aren't GraphTensors.");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
GraphFunction* func = new GraphFunction;
|
||||||
|
func->func = graph_ctx->ToFunction(fn_name, num_inputs, graph_inputs,
|
||||||
|
num_outputs, graph_outputs, status);
|
||||||
|
return wrap(func);
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user