Refactor TF_OutputList in the C unified API to align with the other C++ type
PiperOrigin-RevId: 308422771 Change-Id: Idbf72e7c66f038016e3ef06d6e7c0b787c032d97
This commit is contained in:
parent
c1603be930
commit
4e2e046484
@ -32,6 +32,7 @@ using tensorflow::internal::AbstractOp;
|
||||
using tensorflow::internal::AbstractTensor;
|
||||
using tensorflow::internal::dynamic_cast_helper;
|
||||
using tensorflow::internal::ExecutionContext;
|
||||
using tensorflow::internal::OutputList;
|
||||
using tensorflow::internal::unwrap;
|
||||
using tensorflow::internal::wrap;
|
||||
|
||||
@ -146,11 +147,6 @@ class TF_EagerOp : public AbstractOp {
|
||||
TFE_Context* ctx_;
|
||||
};
|
||||
|
||||
struct TF_OutputList {
|
||||
std::vector<TF_AbstractTensor*> outputs;
|
||||
int expected_num_outputs = -1;
|
||||
};
|
||||
|
||||
struct TF_AbstractFunction {
|
||||
TF_Function* func = nullptr;
|
||||
|
||||
@ -171,7 +167,7 @@ class TF_EagerContext : public ExecutionContext {
|
||||
}
|
||||
|
||||
void ExecuteOperation(AbstractOp* op, int num_inputs,
|
||||
AbstractTensor* const* inputs, TF_OutputList* o,
|
||||
AbstractTensor* const* inputs, OutputList* o,
|
||||
TF_Status* s) override {
|
||||
auto* eager_op = dynamic_cast_helper<TF_EagerOp>(op);
|
||||
if (eager_op == nullptr) {
|
||||
@ -207,7 +203,7 @@ class TF_EagerContext : public ExecutionContext {
|
||||
o->outputs.clear();
|
||||
o->outputs.reserve(num_retvals);
|
||||
for (int i = 0; i < num_retvals; ++i) {
|
||||
o->outputs.push_back(wrap(new EagerTensor(retvals[i])));
|
||||
o->outputs.push_back(new EagerTensor(retvals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
@ -238,7 +234,7 @@ class TF_GraphContext : public ExecutionContext {
|
||||
}
|
||||
|
||||
void ExecuteOperation(AbstractOp* op, int num_inputs,
|
||||
AbstractTensor* const* inputs, TF_OutputList* o,
|
||||
AbstractTensor* const* inputs, OutputList* o,
|
||||
TF_Status* s) override {
|
||||
auto* graph_op = dynamic_cast_helper<TF_GraphOp>(op);
|
||||
if (graph_op == nullptr) {
|
||||
@ -271,7 +267,7 @@ class TF_GraphContext : public ExecutionContext {
|
||||
o->outputs.clear();
|
||||
o->outputs.reserve(num_outputs);
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
o->outputs.push_back(wrap(new GraphTensor({operation, i}, this)));
|
||||
o->outputs.push_back(new GraphTensor({operation, i}, this));
|
||||
}
|
||||
}
|
||||
|
||||
@ -318,15 +314,17 @@ TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions* options,
|
||||
return wrap(ctx);
|
||||
}
|
||||
|
||||
TF_OutputList* TF_NewOutputList() { return new TF_OutputList; }
|
||||
void TF_DeleteOutputList(TF_OutputList* o) { delete o; }
|
||||
TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); }
|
||||
void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); }
|
||||
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
|
||||
TF_Status* s) {
|
||||
o->expected_num_outputs = num_outputs;
|
||||
unwrap(o)->expected_num_outputs = num_outputs;
|
||||
}
|
||||
int TF_OutputListNumOutputs(TF_OutputList* o) {
|
||||
return unwrap(o)->outputs.size();
|
||||
}
|
||||
int TF_OutputListNumOutputs(TF_OutputList* o) { return o->outputs.size(); }
|
||||
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
|
||||
return o->outputs[i];
|
||||
return wrap(unwrap(o)->outputs[i]);
|
||||
}
|
||||
|
||||
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
|
||||
@ -347,7 +345,8 @@ void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
|
||||
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs, TF_OutputList* o,
|
||||
TF_ExecutionContext* ctx, TF_Status* s) {
|
||||
unwrap(ctx)->ExecuteOperation(unwrap(op), num_inputs, &unwrap(*inputs), o, s);
|
||||
unwrap(ctx)->ExecuteOperation(unwrap(op), num_inputs, &unwrap(*inputs),
|
||||
unwrap(o), s);
|
||||
}
|
||||
|
||||
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
||||
|
@ -16,6 +16,8 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_PRIVATE_H_
|
||||
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_PRIVATE_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
|
||||
@ -36,6 +38,11 @@ struct AbstractTensor {
|
||||
const AbstractTensorKind k;
|
||||
};
|
||||
|
||||
struct OutputList {
|
||||
std::vector<AbstractTensor*> outputs;
|
||||
int expected_num_outputs = -1;
|
||||
};
|
||||
|
||||
struct AbstractOp {
|
||||
// Needed to implement our own version of RTTI since dynamic_cast is not
|
||||
// supported in mobile builds.
|
||||
@ -60,7 +67,7 @@ struct ExecutionContext {
|
||||
ExecutionContextKind getKind() const { return k; }
|
||||
|
||||
virtual void ExecuteOperation(AbstractOp* op, int num_inputs,
|
||||
AbstractTensor* const* inputs, TF_OutputList* o,
|
||||
AbstractTensor* const* inputs, OutputList* o,
|
||||
TF_Status* s) = 0;
|
||||
virtual AbstractOp* CreateOperation() = 0;
|
||||
virtual void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) = 0;
|
||||
@ -89,6 +96,7 @@ struct ExecutionContext {
|
||||
MAKE_WRAP_UNWRAP(TF_ExecutionContext, ExecutionContext)
|
||||
MAKE_WRAP_UNWRAP(TF_AbstractTensor, AbstractTensor)
|
||||
MAKE_WRAP_UNWRAP(TF_AbstractOp, AbstractOp)
|
||||
MAKE_WRAP_UNWRAP(TF_OutputList, OutputList)
|
||||
|
||||
template <typename T, typename S>
|
||||
T* dynamic_cast_helper(S source) {
|
||||
|
Loading…
Reference in New Issue
Block a user