Add support for setting up a TF_OutputList from the client and use it to build function with multiple results
PiperOrigin-RevId: 311585364 Change-Id: I5245fd0f5e5c0e8e7e22350d970c508e0154d59b
This commit is contained in:
		
							parent
							
								
									ba43780830
								
							
						
					
					
						commit
						215616fddc
					
				| @ -127,6 +127,10 @@ int TF_OutputListNumOutputs(TF_OutputList* o) { | ||||
| TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) { | ||||
|   return wrap(unwrap(o)->outputs[i]); | ||||
| } | ||||
| void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor, | ||||
|                            TF_Status* s) { | ||||
|   unwrap(o)->outputs.push_back(unwrap(tensor)); | ||||
| } | ||||
| 
 | ||||
| void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type, | ||||
|                             TF_Status* s) { | ||||
|  | ||||
| @ -88,19 +88,21 @@ void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name, | ||||
| void TF_DeleteAbstractTensor(TF_AbstractTensor*); | ||||
| 
 | ||||
| // TF_OutputList holds the list of TF_AbstractTensor that results from executing
 | ||||
| // an operation.
 | ||||
| // It just lets us not specify the number of outputs of an operation
 | ||||
| // beforehand. This forces a memory allocation in the runtime, which is bad, but
 | ||||
| // it allows for generic code.
 | ||||
| // TODO(aminim): the description above isn't clear with respect to
 | ||||
| // TF_OutputListNumOutputs and the current eager implementation which requires
 | ||||
| // the number of outputs to be set by the client.
 | ||||
| // an operation, or provided to create a function.
 | ||||
| // When executing an operation in an eager context, the expected number of
 | ||||
| // outputs must be set beforehand with `TF_OutputListSetNumOutputs`.
 | ||||
| typedef struct TF_OutputList TF_OutputList; | ||||
| TF_OutputList* TF_NewOutputList(); | ||||
| void TF_DeleteOutputList(TF_OutputList* o); | ||||
| void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*); | ||||
| // Prepare tracing to the expected number of output for an operation.
 | ||||
| void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs, TF_Status*); | ||||
| // Return the number of outputs in the list.
 | ||||
| int TF_OutputListNumOutputs(TF_OutputList* o); | ||||
| // Return the `i`th output in the list.
 | ||||
| TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i); | ||||
| // Append a tensor at the end of the output list, growing its size by one.
 | ||||
| void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor, | ||||
|                            TF_Status*); | ||||
| 
 | ||||
| // TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
 | ||||
| // capture some inputs and then add a node in the graph. The output tensors are
 | ||||
|  | ||||
| @ -139,6 +139,10 @@ class GraphContext : public ExecutionContext { | ||||
|       return; | ||||
|     } | ||||
|     auto* tf_opdesc = graph_op->op_.release(); | ||||
|     if (tf_opdesc == nullptr) { | ||||
|       TF_SetStatus(s, TF_INVALID_ARGUMENT, "AbstractOp is incomplete."); | ||||
|       return; | ||||
|     } | ||||
|     for (int i = 0; i < num_inputs; ++i) { | ||||
|       auto* graph_tensor = dyncast<GraphTensor>(inputs[i]); | ||||
|       if (!graph_tensor) { | ||||
|  | ||||
| @ -169,7 +169,152 @@ TEST_P(UnifiedCAPI, TestBasicGraph) { | ||||
|   TF_DeleteExecutionContext(eager_execution_ctx); | ||||
| } | ||||
| 
 | ||||
| TEST_P(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) { | ||||
| TEST_P(UnifiedCAPI, TestMultiOutputGraph) { | ||||
|   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( | ||||
|       TF_NewStatus(), TF_DeleteStatus); | ||||
|   TF_Status* s = status.get(); | ||||
| 
 | ||||
|   // Start a new function / execution context.
 | ||||
|   string fn_name = "two_adds"; | ||||
|   TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s); | ||||
|   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); | ||||
| 
 | ||||
|   auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s); | ||||
|   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); | ||||
|   auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s); | ||||
|   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); | ||||
| 
 | ||||
|   // Create a first "Add" computing `arg0 + arg1`.
 | ||||
|   TF_AbstractTensor* add_output1; | ||||
|   { | ||||
|     // Build an abstract operation, inputs and output.
 | ||||
|     auto* add_op = TF_NewAbstractOp(graph_ctx); | ||||
|     TF_AbstractOpSetOpType(add_op, "Add", s); | ||||
|     ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); | ||||
|     TF_AbstractOpSetOpName(add_op, "my_add1", s); | ||||
|     ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); | ||||
|     TF_AbstractTensor* inputs[2] = {arg0, arg1}; | ||||
|     TF_OutputList* add_outputs = TF_NewOutputList(); | ||||
|     // Trace the operation now (create a node in the graph).
 | ||||
|     TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s); | ||||
|     ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); | ||||
|     TF_DeleteAbstractOp(add_op); | ||||
|     // Extract the resulting tensor.
 | ||||
|     add_output1 = TF_OutputListGet(add_outputs, 0); | ||||
|     TF_DeleteOutputList(add_outputs); | ||||
|   } | ||||
| 
 | ||||
|   // Same with a second "Add" computing `arg1 + arg1`.
 | ||||
|   TF_AbstractTensor* add_output2; | ||||
|   { | ||||
|     // Build an abstract operation, inputs and output.
 | ||||
|     auto* add_op = TF_NewAbstractOp(graph_ctx); | ||||
|     TF_AbstractOpSetOpType(add_op, "Add", s); | ||||
|     ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); | ||||
|     TF_AbstractOpSetOpName(add_op, "my_add2", s); | ||||
|     ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); | ||||
|     TF_AbstractTensor* inputs[2] = {arg1, arg1}; | ||||
|     TF_OutputList* add_outputs = TF_NewOutputList(); | ||||
|     // Trace the operation now (create a node in the graph).
 | ||||
|     TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s); | ||||
|     ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); | ||||
|     TF_DeleteAbstractOp(add_op); | ||||
|     // Extract the resulting tensor.
 | ||||
|     add_output2 = TF_OutputListGet(add_outputs, 0); | ||||
|     TF_DeleteOutputList(add_outputs); | ||||
|   } | ||||
| 
 | ||||
|   // Finalize the function by providing the returned values.
 | ||||
|   TF_AbstractFunction* func; | ||||
|   { | ||||
|     // We want to return the output of both add operations, create a new list
 | ||||
|     // and populate it.
 | ||||
|     TF_OutputList* func_outputs = TF_NewOutputList(); | ||||
|     TF_OutputListPushBack(func_outputs, add_output1, s); | ||||
|     ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); | ||||
|     TF_OutputListPushBack(func_outputs, add_output2, s); | ||||
|     ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); | ||||
|     func = TF_FinalizeFunction(graph_ctx, func_outputs, s); | ||||
|     ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); | ||||
|     TF_DeleteOutputList(func_outputs); | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * We traced so far this function: | ||||
|    * | ||||
|    *   def two_adds(a, b): | ||||
|    *     my_add1 = a + b | ||||
|    *     my_add2 = b + b | ||||
|    *     return my_add1, my_add2 | ||||
|    * | ||||
|    * Now we will execute this function with an eager context: | ||||
|    * | ||||
|    *   output1, output2 = two_adds(2.0, 3.0) | ||||
|    * | ||||
|    * and check that we got 5.0 and 6.0 as results. | ||||
|    */ | ||||
| 
 | ||||
|   // Build eager context.
 | ||||
|   TFE_ContextOptions* opts = TFE_NewContextOptions(); | ||||
|   TF_ExecutionContext* eager_execution_ctx = | ||||
|       TF_NewEagerExecutionContext(opts, s); | ||||
|   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); | ||||
|   TFE_DeleteContextOptions(opts); | ||||
| 
 | ||||
|   TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, s); | ||||
|   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); | ||||
| 
 | ||||
|   // Build the abstract op to run the function.
 | ||||
|   TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx); | ||||
|   TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), s); | ||||
|   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); | ||||
| 
 | ||||
|   // Build two abstract input tensors as function arguments.
 | ||||
|   std::vector<TF_AbstractTensor*> func_args; | ||||
|   { | ||||
|     TFE_Context* eager_ctx = | ||||
|         TF_ExecutionContextGetTFEContext(eager_execution_ctx); | ||||
|     TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f); | ||||
|     func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s)); | ||||
|     ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); | ||||
|     input_eager = TestScalarTensorHandle(eager_ctx, 3.0f); | ||||
|     func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s)); | ||||
|     ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); | ||||
|   } | ||||
| 
 | ||||
|   TF_OutputList* func_outputs = TF_NewOutputList(); | ||||
|   TF_OutputListSetNumOutputs(func_outputs, 2, s); | ||||
|   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); | ||||
|   TF_ExecuteOperation(fn_op, func_args.size(), func_args.data(), func_outputs, | ||||
|                       eager_execution_ctx, s); | ||||
|   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); | ||||
|   TF_DeleteAbstractOp(fn_op); | ||||
|   for (TF_AbstractTensor* t : func_args) TF_DeleteAbstractTensor(t); | ||||
| 
 | ||||
|   ASSERT_EQ(2, TF_OutputListNumOutputs(func_outputs)); | ||||
|   float results[2]; | ||||
|   for (int idx = 0; idx < 2; ++idx) { | ||||
|     TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx); | ||||
|     TFE_TensorHandle* handle = TF_AbstractTensorGetEagerTensor(result, s); | ||||
|     ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); | ||||
|     TF_Tensor* f_t = TFE_TensorHandleResolve(handle, s); | ||||
|     ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); | ||||
|     results[idx] = *static_cast<float*>(TF_TensorData(f_t)); | ||||
|     TF_DeleteTensor(f_t); | ||||
|   } | ||||
|   ASSERT_EQ(results[0], 5.0); | ||||
|   ASSERT_EQ(results[1], 6.0); | ||||
| 
 | ||||
|   for (int idx = 0; idx < 2; ++idx) { | ||||
|     TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx); | ||||
|     TF_DeleteAbstractTensor(result); | ||||
|   } | ||||
|   TF_DeleteOutputList(func_outputs); | ||||
|   TF_DeleteExecutionContext(eager_execution_ctx); | ||||
|   TF_DeleteAbstractFunction(func); | ||||
| } | ||||
| 
 | ||||
| TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) { | ||||
|   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( | ||||
|       TF_NewStatus(), TF_DeleteStatus); | ||||
|   TFE_ContextOptions* opts = TFE_NewContextOptions(); | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user