Add UpdateEdge unit test in c_api_test.cc.
This commit is contained in:
parent
6ed5f2c807
commit
929b0feb8c
@ -1524,6 +1524,12 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status);
|
|||||||
TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp(
|
TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp(
|
||||||
const char* name, TF_Status* status);
|
const char* name, TF_Status* status);
|
||||||
|
|
||||||
|
// Update edge, switch input/ output in a node
|
||||||
|
TF_CAPI_EXPORT extern void TF_UpdateEdge(TF_Graph* graph,
|
||||||
|
TF_Output new_src,
|
||||||
|
TF_Input dst,
|
||||||
|
TF_Status* status);
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
// In-process TensorFlow server functionality, for use in distributed training.
|
// In-process TensorFlow server functionality, for use in distributed training.
|
||||||
// A Server instance encapsulates a set of devices and a Session target that
|
// A Server instance encapsulates a set of devices and a Session target that
|
||||||
@ -1573,12 +1579,6 @@ TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server);
|
|||||||
TF_CAPI_EXPORT extern void TF_RegisterLogListener(
|
TF_CAPI_EXPORT extern void TF_RegisterLogListener(
|
||||||
void (*listener)(const char*));
|
void (*listener)(const char*));
|
||||||
|
|
||||||
// Update edge, switch input/ output in a node
|
|
||||||
TF_CAPI_EXPORT extern void TF_UpdateEdge(TF_Graph* graph,
|
|
||||||
TF_Output new_src,
|
|
||||||
TF_Input dst,
|
|
||||||
TF_Status* status);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} /* end extern "C" */
|
} /* end extern "C" */
|
||||||
#endif
|
#endif
|
||||||
|
@ -634,6 +634,40 @@ TEST(CAPI, Graph) {
|
|||||||
TF_DeleteStatus(s);
|
TF_DeleteStatus(s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CAPI, UpdateEdge) {
|
||||||
|
TF_Status* s = TF_NewStatus();
|
||||||
|
TF_Graph* graph = TF_NewGraph();
|
||||||
|
|
||||||
|
// Make two scalar constants.
|
||||||
|
TF_Operation* one = ScalarConst(1, graph, s, "one");
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
|
TF_Operation* two = ScalarConst(2, graph, s, "two");
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
|
// Add oper.
|
||||||
|
TF_Operation* add = Add(one, two, graph, s, "add");
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
|
// Add another oper to the graph.
|
||||||
|
TF_Operation* neg = Neg(add, graph, s, "neg");
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
|
NodeDef node_def_neg;
|
||||||
|
ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
|
||||||
|
EXPECT_EQ(string("add"), node_def_neg.input(0));
|
||||||
|
|
||||||
|
// update edge of neg
|
||||||
|
TF_UpdateEdge(graph, TF_Output{one, 0}, TF_Input{neg, 0}, s);
|
||||||
|
|
||||||
|
ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
|
||||||
|
EXPECT_EQ(string("one:0"), node_def_neg.input(0));
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
TF_DeleteGraph(graph);
|
||||||
|
TF_DeleteStatus(s);
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
TODO(skyewm): this test currently DCHECKs, change to bad status
|
TODO(skyewm): this test currently DCHECKs, change to bad status
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user