From 4e2e04648422b23a07fc9ff807a20e4ef395251b Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Sat, 25 Apr 2020 10:03:39 -0700 Subject: [PATCH] Refactor TF_OutputList in the C unified API to align with the other C++ type PiperOrigin-RevId: 308422771 Change-Id: Idbf72e7c66f038016e3ef06d6e7c0b787c032d97 --- .../c/eager/c_api_unified_experimental.cc | 29 +++++++++---------- .../c_api_unified_experimental_private.h | 10 ++++++- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/tensorflow/c/eager/c_api_unified_experimental.cc b/tensorflow/c/eager/c_api_unified_experimental.cc index d111df5aba0..d415ace8a9b 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.cc +++ b/tensorflow/c/eager/c_api_unified_experimental.cc @@ -32,6 +32,7 @@ using tensorflow::internal::AbstractOp; using tensorflow::internal::AbstractTensor; using tensorflow::internal::dynamic_cast_helper; using tensorflow::internal::ExecutionContext; +using tensorflow::internal::OutputList; using tensorflow::internal::unwrap; using tensorflow::internal::wrap; @@ -146,11 +147,6 @@ class TF_EagerOp : public AbstractOp { TFE_Context* ctx_; }; -struct TF_OutputList { - std::vector outputs; - int expected_num_outputs = -1; -}; - struct TF_AbstractFunction { TF_Function* func = nullptr; @@ -171,7 +167,7 @@ class TF_EagerContext : public ExecutionContext { } void ExecuteOperation(AbstractOp* op, int num_inputs, - AbstractTensor* const* inputs, TF_OutputList* o, + AbstractTensor* const* inputs, OutputList* o, TF_Status* s) override { auto* eager_op = dynamic_cast_helper(op); if (eager_op == nullptr) { @@ -207,7 +203,7 @@ class TF_EagerContext : public ExecutionContext { o->outputs.clear(); o->outputs.reserve(num_retvals); for (int i = 0; i < num_retvals; ++i) { - o->outputs.push_back(wrap(new EagerTensor(retvals[i]))); + o->outputs.push_back(new EagerTensor(retvals[i])); } } @@ -238,7 +234,7 @@ class TF_GraphContext : public ExecutionContext { } void ExecuteOperation(AbstractOp* op, int num_inputs, - AbstractTensor* const* inputs, TF_OutputList* o, + AbstractTensor* const* inputs, OutputList* o, TF_Status* s) override { auto* graph_op = dynamic_cast_helper(op); if (graph_op == nullptr) { @@ -271,7 +267,7 @@ class TF_GraphContext : public ExecutionContext { o->outputs.clear(); o->outputs.reserve(num_outputs); for (int i = 0; i < num_outputs; ++i) { - o->outputs.push_back(wrap(new GraphTensor({operation, i}, this))); + o->outputs.push_back(new GraphTensor({operation, i}, this)); } } @@ -318,15 +314,17 @@ TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions* options, return wrap(ctx); } -TF_OutputList* TF_NewOutputList() { return new TF_OutputList; } -void TF_DeleteOutputList(TF_OutputList* o) { delete o; } +TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); } +void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); } void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs, TF_Status* s) { - o->expected_num_outputs = num_outputs; + unwrap(o)->expected_num_outputs = num_outputs; +} +int TF_OutputListNumOutputs(TF_OutputList* o) { + return unwrap(o)->outputs.size(); } -int TF_OutputListNumOutputs(TF_OutputList* o) { return o->outputs.size(); } TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) { - return o->outputs[i]; + return wrap(unwrap(o)->outputs[i]); } void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type, @@ -347,7 +345,8 @@ void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name, void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs, TF_AbstractTensor* const* inputs, TF_OutputList* o, TF_ExecutionContext* ctx, TF_Status* s) { - unwrap(ctx)->ExecuteOperation(unwrap(op), num_inputs, &unwrap(*inputs), o, s); + unwrap(ctx)->ExecuteOperation(unwrap(op), num_inputs, &unwrap(*inputs), + unwrap(o), s); } TF_AbstractFunction* TF_ExecutionContextToFunction( diff --git a/tensorflow/c/eager/c_api_unified_experimental_private.h b/tensorflow/c/eager/c_api_unified_experimental_private.h index 351637c0083..34b8443f8cc 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_private.h +++ b/tensorflow/c/eager/c_api_unified_experimental_private.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_PRIVATE_H_ #define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_PRIVATE_H_ +#include + #include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/core/platform/casts.h" @@ -36,6 +38,11 @@ struct AbstractTensor { const AbstractTensorKind k; }; +struct OutputList { + std::vector outputs; + int expected_num_outputs = -1; +}; + struct AbstractOp { // Needed to implement our own version of RTTI since dynamic_cast is not // supported in mobile builds. @@ -60,7 +67,7 @@ struct ExecutionContext { ExecutionContextKind getKind() const { return k; } virtual void ExecuteOperation(AbstractOp* op, int num_inputs, - AbstractTensor* const* inputs, TF_OutputList* o, + AbstractTensor* const* inputs, OutputList* o, TF_Status* s) = 0; virtual AbstractOp* CreateOperation() = 0; virtual void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) = 0; @@ -89,6 +96,7 @@ struct ExecutionContext { MAKE_WRAP_UNWRAP(TF_ExecutionContext, ExecutionContext) MAKE_WRAP_UNWRAP(TF_AbstractTensor, AbstractTensor) MAKE_WRAP_UNWRAP(TF_AbstractOp, AbstractOp) +MAKE_WRAP_UNWRAP(TF_OutputList, OutputList) template T* dynamic_cast_helper(S source) {