From f25517d5b25b322735e1aea6231e286d6b596211 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 20 Sep 2020 10:26:57 -0500 Subject: [PATCH 1/3] Add TF_UpdateEdge C API. --- tensorflow/c/c_api.cc | 42 ++++++++++++++++++++++++++++++++++++++++++ tensorflow/c/c_api.h | 6 ++++++ 2 files changed, 48 insertions(+) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 2e1759ecea0..3ee3d50195f 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -2488,6 +2488,48 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) { return ret; } +void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, + TF_Status* status) { + using tensorflow::RecordMutation; + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(&new_src.oper->node); + + if (ic->num_outputs() <= new_src.index) { + status->status = tensorflow::errors::OutOfRange( + "Cannot update edge. Output index [", new_src.index, + "] is greater than the number of total outputs [", ic->num_outputs(), + "]."); + return; + } + tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index); + + tensorflow::shape_inference::InferenceContext* ic_dst = + graph->refiner.GetContext(&dst.oper->node); + if (ic_dst->num_inputs() <= dst.index) { + status->status = tensorflow::errors::OutOfRange( + "Cannot update edge. Input index [", dst.index, + "] is greater than the number of total inputs [", ic_dst->num_inputs(), + "]."); + return; + } + if (!ic_dst->MergeInput(dst.index, shape)) { + status->status = tensorflow::errors::InvalidArgument( + "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape), + " and ", ic_dst->DebugString(ic_dst->input(dst.index)), "."); + return; + } + status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index, + &dst.oper->node, dst.index); + + if (TF_GetCode(status) == TF_OK) { + // This modification only updates the destination node for + // the purposes of running this graph in a session. Thus, we don't + // record the source node as being modified. + RecordMutation(graph, *dst.oper, "updating input tensor"); + } +} + // TF_Server functions ---------------------------------------------- #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 0b4d9993e4d..e573225410f 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1573,6 +1573,12 @@ TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server); TF_CAPI_EXPORT extern void TF_RegisterLogListener( void (*listener)(const char*)); +// Update edge, switch input/ output in a node +TF_CAPI_EXPORT extern void TF_UpdateEdge(TF_Graph* graph, + TF_Output new_src, + TF_Input dst, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif From 6ed5f2c80780f858e0f2fa622561ed7724af3b6e Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 3 Oct 2020 07:47:19 -0500 Subject: [PATCH 2/3] Call TF_UpdateEdge in python_api.cc. --- tensorflow/c/python_api.cc | 38 +------------------------------------- 1 file changed, 1 insertion(+), 37 deletions(-) diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 8d7a0cd3a18..ba6d22ba228 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -57,43 +57,7 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status) { - mutex_lock l(graph->mu); - tensorflow::shape_inference::InferenceContext* ic = - graph->refiner.GetContext(&new_src.oper->node); - - if (ic->num_outputs() <= new_src.index) { - status->status = tensorflow::errors::OutOfRange( - "Cannot update edge. Output index [", new_src.index, - "] is greater than the number of total outputs [", ic->num_outputs(), - "]."); - return; - } - tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index); - - tensorflow::shape_inference::InferenceContext* ic_dst = - graph->refiner.GetContext(&dst.oper->node); - if (ic_dst->num_inputs() <= dst.index) { - status->status = tensorflow::errors::OutOfRange( - "Cannot update edge. Input index [", dst.index, - "] is greater than the number of total inputs [", ic_dst->num_inputs(), - "]."); - return; - } - if (!ic_dst->MergeInput(dst.index, shape)) { - status->status = tensorflow::errors::InvalidArgument( - "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape), - " and ", ic_dst->DebugString(ic_dst->input(dst.index)), "."); - return; - } - status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index, - &dst.oper->node, dst.index); - - if (TF_GetCode(status) == TF_OK) { - // This modification only updates the destination node for - // the purposes of running this graph in a session. Thus, we don't - // record the source node as being modified. - RecordMutation(graph, *dst.oper, "updating input tensor"); - } + TF_UpdateEdge(graph, new_src, dst, status); } void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) { From 929b0feb8c84d93b89527b3b9544b69734cc2637 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 3 Oct 2020 07:51:47 -0500 Subject: [PATCH 3/3] Add UpdateEdge unit test in c_api_test.cc. --- tensorflow/c/c_api.h | 12 ++++++------ tensorflow/c/c_api_test.cc | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index e573225410f..fce1d77c3ec 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1524,6 +1524,12 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status); TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp( const char* name, TF_Status* status); +// Update edge, switch input/ output in a node +TF_CAPI_EXPORT extern void TF_UpdateEdge(TF_Graph* graph, + TF_Output new_src, + TF_Input dst, + TF_Status* status); + // -------------------------------------------------------------------------- // In-process TensorFlow server functionality, for use in distributed training. // A Server instance encapsulates a set of devices and a Session target that @@ -1573,12 +1579,6 @@ TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server); TF_CAPI_EXPORT extern void TF_RegisterLogListener( void (*listener)(const char*)); -// Update edge, switch input/ output in a node -TF_CAPI_EXPORT extern void TF_UpdateEdge(TF_Graph* graph, - TF_Output new_src, - TF_Input dst, - TF_Status* status); - #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index bbbbb8f7d56..fc1fdccee16 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -634,6 +634,40 @@ TEST(CAPI, Graph) { TF_DeleteStatus(s); } +TEST(CAPI, UpdateEdge) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // Make two scalar constants. + TF_Operation* one = ScalarConst(1, graph, s, "one"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + TF_Operation* two = ScalarConst(2, graph, s, "two"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Add oper. + TF_Operation* add = Add(one, two, graph, s, "add"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Add another oper to the graph. + TF_Operation* neg = Neg(add, graph, s, "neg"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + NodeDef node_def_neg; + ASSERT_TRUE(GetNodeDef(neg, &node_def_neg)); + EXPECT_EQ(string("add"), node_def_neg.input(0)); + + // update edge of neg + TF_UpdateEdge(graph, TF_Output{one, 0}, TF_Input{neg, 0}, s); + + ASSERT_TRUE(GetNodeDef(neg, &node_def_neg)); + EXPECT_EQ(string("one:0"), node_def_neg.input(0)); + + // Clean up + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + /* TODO(skyewm): this test currently DCHECKs, change to bad status