diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index e573225410f..fce1d77c3ec 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1524,6 +1524,12 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status); TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp( const char* name, TF_Status* status); +// Update edge, switch input/ output in a node +TF_CAPI_EXPORT extern void TF_UpdateEdge(TF_Graph* graph, + TF_Output new_src, + TF_Input dst, + TF_Status* status); + // -------------------------------------------------------------------------- // In-process TensorFlow server functionality, for use in distributed training. // A Server instance encapsulates a set of devices and a Session target that @@ -1573,12 +1579,6 @@ TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server); TF_CAPI_EXPORT extern void TF_RegisterLogListener( void (*listener)(const char*)); -// Update edge, switch input/ output in a node -TF_CAPI_EXPORT extern void TF_UpdateEdge(TF_Graph* graph, - TF_Output new_src, - TF_Input dst, - TF_Status* status); - #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index bbbbb8f7d56..fc1fdccee16 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -634,6 +634,40 @@ TEST(CAPI, Graph) { TF_DeleteStatus(s); } +TEST(CAPI, UpdateEdge) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // Make two scalar constants. + TF_Operation* one = ScalarConst(1, graph, s, "one"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + TF_Operation* two = ScalarConst(2, graph, s, "two"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Add oper. + TF_Operation* add = Add(one, two, graph, s, "add"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Add another oper to the graph. + TF_Operation* neg = Neg(add, graph, s, "neg"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + NodeDef node_def_neg; + ASSERT_TRUE(GetNodeDef(neg, &node_def_neg)); + EXPECT_EQ(string("add"), node_def_neg.input(0)); + + // update edge of neg + TF_UpdateEdge(graph, TF_Output{one, 0}, TF_Input{neg, 0}, s); + + ASSERT_TRUE(GetNodeDef(neg, &node_def_neg)); + EXPECT_EQ(string("one:0"), node_def_neg.input(0)); + + // Clean up + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + /* TODO(skyewm): this test currently DCHECKs, change to bad status