Add support for setting up a TF_OutputList from the client and use it to build a function with multiple results
PiperOrigin-RevId: 311585364 Change-Id: I5245fd0f5e5c0e8e7e22350d970c508e0154d59b
Commit 215616fddc (parent ba43780830)
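The change is easiest to read from the client's side. Below is a minimal sketch, not part of the commit, of the flow the new API enables: trace a couple of ops into a function body, collect their results with TF_OutputListPushBack, and finalize a function that returns all of them (this mirrors the TestMultiOutputGraph test further down; the header path is assumed and status checks are omitted for brevity).

// Sketch only: header path assumed, all TF_GetCode(s) checks omitted.
#include "tensorflow/c/eager/c_api_unified_experimental.h"

// Trace f(a, b) = (a + b, b + b) and return it as a TF_AbstractFunction.
static TF_AbstractFunction* TraceTwoAdds(TF_Status* s) {
  TF_ExecutionContext* graph_ctx = TF_CreateFunction("two_adds", s);
  TF_AbstractTensor* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
  TF_AbstractTensor* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);

  // Trace two "Add" nodes; each reports its result through an output list.
  const char* op_names[2] = {"my_add1", "my_add2"};
  TF_AbstractTensor* op_inputs[2][2] = {{arg0, arg1}, {arg1, arg1}};
  TF_AbstractTensor* traced[2];
  for (int i = 0; i < 2; ++i) {
    TF_AbstractOp* add_op = TF_NewAbstractOp(graph_ctx);
    TF_AbstractOpSetOpType(add_op, "Add", s);
    TF_AbstractOpSetOpName(add_op, op_names[i], s);
    TF_OutputList* op_outputs = TF_NewOutputList();
    TF_ExecuteOperation(add_op, 2, op_inputs[i], op_outputs, graph_ctx, s);
    TF_DeleteAbstractOp(add_op);
    traced[i] = TF_OutputListGet(op_outputs, 0);
    TF_DeleteOutputList(op_outputs);
  }

  // New in this change: the client grows the list itself, one entry per
  // function result, instead of pre-declaring a fixed number of outputs.
  TF_OutputList* func_outputs = TF_NewOutputList();
  TF_OutputListPushBack(func_outputs, traced[0], s);
  TF_OutputListPushBack(func_outputs, traced[1], s);
  TF_AbstractFunction* func = TF_FinalizeFunction(graph_ctx, func_outputs, s);
  TF_DeleteOutputList(func_outputs);
  return func;
}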
@@ -127,6 +127,10 @@ int TF_OutputListNumOutputs(TF_OutputList* o) {
 TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
   return wrap(unwrap(o)->outputs[i]);
 }
+void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
+                           TF_Status* s) {
+  unwrap(o)->outputs.push_back(unwrap(tensor));
+}
 
 void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
                             TF_Status* s) {
@@ -88,19 +88,21 @@ void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
 void TF_DeleteAbstractTensor(TF_AbstractTensor*);
 
 // TF_OutputList holds the list of TF_AbstractTensor that results from executing
-// an operation.
-// It just lets us not specify the number of outputs of an operation
-// beforehand. This forces a memory allocation in the runtime, which is bad, but
-// it allows for generic code.
-// TODO(aminim): the description above isn't clear with respect to
-// TF_OutputListNumOutputs and the current eager implementation which requires
-// the number of outputs to be set by the client.
+// an operation, or provided to create a function.
+// When executing an operation in an eager context, the expected number of
+// outputs must be set beforehand with `TF_OutputListSetNumOutputs`.
 typedef struct TF_OutputList TF_OutputList;
 TF_OutputList* TF_NewOutputList();
 void TF_DeleteOutputList(TF_OutputList* o);
-void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
+// Prepare tracing to the expected number of output for an operation.
+void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs, TF_Status*);
+// Return the number of outputs in the list.
 int TF_OutputListNumOutputs(TF_OutputList* o);
+// Return the `i`th output in the list.
 TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
+// Append a tensor at the end of the output list, growing its size by one.
+void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
+                           TF_Status*);
 
 // TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
 // capture some inputs and then add a node in the graph. The output tensors are
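The header comments above distinguish the two usage modes: a tracing client grows the list with TF_OutputListPushBack to assemble a function's return values, while an eager caller declares the expected arity with TF_OutputListSetNumOutputs before TF_ExecuteOperation. A hedged sketch of the eager side, assuming `func` came from TF_FinalizeFunction, `eager_execution_ctx` from TF_NewEagerExecutionContext, and `func_args` holds two abstract eager tensors (status checks omitted):

// Assumptions: func, eager_execution_ctx, func_args, and s already exist.
TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, s);
TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
TF_AbstractOpSetOpType(fn_op, "two_adds", s);

TF_OutputList* func_outputs = TF_NewOutputList();
TF_OutputListSetNumOutputs(func_outputs, 2, s);  // eager mode: arity declared up front
TF_ExecuteOperation(fn_op, 2, func_args, func_outputs, eager_execution_ctx, s);
TF_DeleteAbstractOp(fn_op);

for (int i = 0; i < TF_OutputListNumOutputs(func_outputs); ++i) {
  TF_AbstractTensor* result = TF_OutputListGet(func_outputs, i);
  // ... consume `result`, e.g. via TF_AbstractTensorGetEagerTensor ...
  TF_DeleteAbstractTensor(result);
}
TF_DeleteOutputList(func_outputs);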
@@ -139,6 +139,10 @@ class GraphContext : public ExecutionContext {
       return;
     }
     auto* tf_opdesc = graph_op->op_.release();
+    if (tf_opdesc == nullptr) {
+      TF_SetStatus(s, TF_INVALID_ARGUMENT, "AbstractOp is incomplete.");
+      return;
+    }
     for (int i = 0; i < num_inputs; ++i) {
       auto* graph_tensor = dyncast<GraphTensor>(inputs[i]);
       if (!graph_tensor) {
@@ -169,7 +169,152 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {
   TF_DeleteExecutionContext(eager_execution_ctx);
 }
 
-TEST_P(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
+TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TF_Status* s = status.get();
+
+  // Start a new function / execution context.
+  string fn_name = "two_adds";
+  TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+  auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+  // Create a first "Add" computing `arg0 + arg1`.
+  TF_AbstractTensor* add_output1;
+  {
+    // Build an abstract operation, inputs and output.
+    auto* add_op = TF_NewAbstractOp(graph_ctx);
+    TF_AbstractOpSetOpType(add_op, "Add", s);
+    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    TF_AbstractOpSetOpName(add_op, "my_add1", s);
+    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    TF_AbstractTensor* inputs[2] = {arg0, arg1};
+    TF_OutputList* add_outputs = TF_NewOutputList();
+    // Trace the operation now (create a node in the graph).
+    TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
+    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    TF_DeleteAbstractOp(add_op);
+    // Extract the resulting tensor.
+    add_output1 = TF_OutputListGet(add_outputs, 0);
+    TF_DeleteOutputList(add_outputs);
+  }
+
+  // Same with a second "Add" computing `arg1 + arg1`.
+  TF_AbstractTensor* add_output2;
+  {
+    // Build an abstract operation, inputs and output.
+    auto* add_op = TF_NewAbstractOp(graph_ctx);
+    TF_AbstractOpSetOpType(add_op, "Add", s);
+    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    TF_AbstractOpSetOpName(add_op, "my_add2", s);
+    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    TF_AbstractTensor* inputs[2] = {arg1, arg1};
+    TF_OutputList* add_outputs = TF_NewOutputList();
+    // Trace the operation now (create a node in the graph).
+    TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
+    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    TF_DeleteAbstractOp(add_op);
+    // Extract the resulting tensor.
+    add_output2 = TF_OutputListGet(add_outputs, 0);
+    TF_DeleteOutputList(add_outputs);
+  }
+
+  // Finalize the function by providing the returned values.
+  TF_AbstractFunction* func;
+  {
+    // We want to return the output of both add operations, create a new list
+    // and populate it.
+    TF_OutputList* func_outputs = TF_NewOutputList();
+    TF_OutputListPushBack(func_outputs, add_output1, s);
+    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    TF_OutputListPushBack(func_outputs, add_output2, s);
+    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    func = TF_FinalizeFunction(graph_ctx, func_outputs, s);
+    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    TF_DeleteOutputList(func_outputs);
+  }
+
+  /**
+   * We traced so far this function:
+   *
+   *   def two_adds(a, b):
+   *     my_add1 = a + b
+   *     my_add2 = b + b
+   *     return my_add1, my_add2
+   *
+   * Now we will execute this function with an eager context:
+   *
+   *   output1, output2 = two_adds(2.0, 3.0)
+   *
+   * and check that we got 5.0 and 6.0 as results.
+   */
+
+  // Build eager context.
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TF_ExecutionContext* eager_execution_ctx =
+      TF_NewEagerExecutionContext(opts, s);
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TFE_DeleteContextOptions(opts);
+
+  TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, s);
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+  // Build the abstract op to run the function.
+  TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
+  TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), s);
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+  // Build two abstract input tensors as function arguments.
+  std::vector<TF_AbstractTensor*> func_args;
+  {
+    TFE_Context* eager_ctx =
+        TF_ExecutionContextGetTFEContext(eager_execution_ctx);
+    TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
+    func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
+    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    input_eager = TestScalarTensorHandle(eager_ctx, 3.0f);
+    func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
+    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  }
+
+  TF_OutputList* func_outputs = TF_NewOutputList();
+  TF_OutputListSetNumOutputs(func_outputs, 2, s);
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_ExecuteOperation(fn_op, func_args.size(), func_args.data(), func_outputs,
+                      eager_execution_ctx, s);
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  TF_DeleteAbstractOp(fn_op);
+  for (TF_AbstractTensor* t : func_args) TF_DeleteAbstractTensor(t);
+
+  ASSERT_EQ(2, TF_OutputListNumOutputs(func_outputs));
+  float results[2];
+  for (int idx = 0; idx < 2; ++idx) {
+    TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
+    TFE_TensorHandle* handle = TF_AbstractTensorGetEagerTensor(result, s);
+    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    TF_Tensor* f_t = TFE_TensorHandleResolve(handle, s);
+    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    results[idx] = *static_cast<float*>(TF_TensorData(f_t));
+    TF_DeleteTensor(f_t);
+  }
+  ASSERT_EQ(results[0], 5.0);
+  ASSERT_EQ(results[1], 6.0);
+
+  for (int idx = 0; idx < 2; ++idx) {
+    TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
+    TF_DeleteAbstractTensor(result);
+  }
+  TF_DeleteOutputList(func_outputs);
+  TF_DeleteExecutionContext(eager_execution_ctx);
+  TF_DeleteAbstractFunction(func);
+}
+
+TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TFE_ContextOptions* opts = TFE_NewContextOptions();