Refactor TF_OutputList in the C unified API to align with the other C++ type

PiperOrigin-RevId: 308422771
Change-Id: Idbf72e7c66f038016e3ef06d6e7c0b787c032d97
This commit is contained in:
Mehdi Amini 2020-04-25 10:03:39 -07:00 committed by TensorFlower Gardener
parent c1603be930
commit 4e2e046484
2 changed files with 23 additions and 16 deletions

View File

@ -32,6 +32,7 @@ using tensorflow::internal::AbstractOp;
using tensorflow::internal::AbstractTensor;
using tensorflow::internal::dynamic_cast_helper;
using tensorflow::internal::ExecutionContext;
using tensorflow::internal::OutputList;
using tensorflow::internal::unwrap;
using tensorflow::internal::wrap;
@ -146,11 +147,6 @@ class TF_EagerOp : public AbstractOp {
TFE_Context* ctx_;
};
struct TF_OutputList {
std::vector<TF_AbstractTensor*> outputs;
int expected_num_outputs = -1;
};
struct TF_AbstractFunction {
TF_Function* func = nullptr;
@ -171,7 +167,7 @@ class TF_EagerContext : public ExecutionContext {
}
void ExecuteOperation(AbstractOp* op, int num_inputs,
AbstractTensor* const* inputs, TF_OutputList* o,
AbstractTensor* const* inputs, OutputList* o,
TF_Status* s) override {
auto* eager_op = dynamic_cast_helper<TF_EagerOp>(op);
if (eager_op == nullptr) {
@ -207,7 +203,7 @@ class TF_EagerContext : public ExecutionContext {
o->outputs.clear();
o->outputs.reserve(num_retvals);
for (int i = 0; i < num_retvals; ++i) {
o->outputs.push_back(wrap(new EagerTensor(retvals[i])));
o->outputs.push_back(new EagerTensor(retvals[i]));
}
}
@ -238,7 +234,7 @@ class TF_GraphContext : public ExecutionContext {
}
void ExecuteOperation(AbstractOp* op, int num_inputs,
AbstractTensor* const* inputs, TF_OutputList* o,
AbstractTensor* const* inputs, OutputList* o,
TF_Status* s) override {
auto* graph_op = dynamic_cast_helper<TF_GraphOp>(op);
if (graph_op == nullptr) {
@ -271,7 +267,7 @@ class TF_GraphContext : public ExecutionContext {
o->outputs.clear();
o->outputs.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
o->outputs.push_back(wrap(new GraphTensor({operation, i}, this)));
o->outputs.push_back(new GraphTensor({operation, i}, this));
}
}
@ -318,15 +314,17 @@ TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions* options,
return wrap(ctx);
}
TF_OutputList* TF_NewOutputList() { return new TF_OutputList; }
void TF_DeleteOutputList(TF_OutputList* o) { delete o; }
TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); }
void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); }
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
TF_Status* s) {
o->expected_num_outputs = num_outputs;
unwrap(o)->expected_num_outputs = num_outputs;
}
int TF_OutputListNumOutputs(TF_OutputList* o) {
return unwrap(o)->outputs.size();
}
int TF_OutputListNumOutputs(TF_OutputList* o) { return o->outputs.size(); }
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
return o->outputs[i];
return wrap(unwrap(o)->outputs[i]);
}
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
@ -347,7 +345,8 @@ void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s) {
unwrap(ctx)->ExecuteOperation(unwrap(op), num_inputs, &unwrap(*inputs), o, s);
unwrap(ctx)->ExecuteOperation(unwrap(op), num_inputs, &unwrap(*inputs),
unwrap(o), s);
}
TF_AbstractFunction* TF_ExecutionContextToFunction(

View File

@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_PRIVATE_H_
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_PRIVATE_H_
#include <vector>
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/core/platform/casts.h"
@ -36,6 +38,11 @@ struct AbstractTensor {
const AbstractTensorKind k;
};
struct OutputList {
std::vector<AbstractTensor*> outputs;
int expected_num_outputs = -1;
};
struct AbstractOp {
// Needed to implement our own version of RTTI since dynamic_cast is not
// supported in mobile builds.
@ -60,7 +67,7 @@ struct ExecutionContext {
ExecutionContextKind getKind() const { return k; }
virtual void ExecuteOperation(AbstractOp* op, int num_inputs,
AbstractTensor* const* inputs, TF_OutputList* o,
AbstractTensor* const* inputs, OutputList* o,
TF_Status* s) = 0;
virtual AbstractOp* CreateOperation() = 0;
virtual void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) = 0;
@ -89,6 +96,7 @@ struct ExecutionContext {
MAKE_WRAP_UNWRAP(TF_ExecutionContext, ExecutionContext)
MAKE_WRAP_UNWRAP(TF_AbstractTensor, AbstractTensor)
MAKE_WRAP_UNWRAP(TF_AbstractOp, AbstractOp)
MAKE_WRAP_UNWRAP(TF_OutputList, OutputList)
template <typename T, typename S>
T* dynamic_cast_helper(S source) {