Merge pull request #43383 from SciSharp:master

PiperOrigin-RevId: 336204525
Change-Id: I570854338a7bec9c4e349c411868fc2b9cfd4b9b
This commit is contained in:
TensorFlower Gardener 2020-10-08 18:22:49 -07:00
commit 683c0f87c7
4 changed files with 81 additions and 37 deletions

View File

@ -2488,6 +2488,48 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
return ret; return ret;
} }
void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
TF_Status* status) {
using tensorflow::RecordMutation;
mutex_lock l(graph->mu);
tensorflow::shape_inference::InferenceContext* ic =
graph->refiner.GetContext(&new_src.oper->node);
if (ic->num_outputs() <= new_src.index) {
status->status = tensorflow::errors::OutOfRange(
"Cannot update edge. Output index [", new_src.index,
"] is greater than the number of total outputs [", ic->num_outputs(),
"].");
return;
}
tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
tensorflow::shape_inference::InferenceContext* ic_dst =
graph->refiner.GetContext(&dst.oper->node);
if (ic_dst->num_inputs() <= dst.index) {
status->status = tensorflow::errors::OutOfRange(
"Cannot update edge. Input index [", dst.index,
"] is greater than the number of total inputs [", ic_dst->num_inputs(),
"].");
return;
}
if (!ic_dst->MergeInput(dst.index, shape)) {
status->status = tensorflow::errors::InvalidArgument(
"Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
" and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
return;
}
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
&dst.oper->node, dst.index);
if (TF_GetCode(status) == TF_OK) {
// This modification only updates the destination node for
// the purposes of running this graph in a session. Thus, we don't
// record the source node as being modified.
RecordMutation(graph, *dst.oper, "updating input tensor");
}
}
// TF_Server functions ---------------------------------------------- // TF_Server functions ----------------------------------------------
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)

View File

@ -1524,6 +1524,10 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status);
TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp( TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp(
const char* name, TF_Status* status); const char* name, TF_Status* status);
// Update edge, switch input/ output in a node
TF_CAPI_EXPORT extern void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src,
TF_Input dst, TF_Status* status);
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
// In-process TensorFlow server functionality, for use in distributed training. // In-process TensorFlow server functionality, for use in distributed training.
// A Server instance encapsulates a set of devices and a Session target that // A Server instance encapsulates a set of devices and a Session target that

View File

@ -634,6 +634,40 @@ TEST(CAPI, Graph) {
TF_DeleteStatus(s); TF_DeleteStatus(s);
} }
TEST(CAPI, UpdateEdge) {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
// Make two scalar constants.
TF_Operation* one = ScalarConst(1, graph, s, "one");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* two = ScalarConst(2, graph, s, "two");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Add oper.
TF_Operation* add = Add(one, two, graph, s, "add");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Add another oper to the graph.
TF_Operation* neg = Neg(add, graph, s, "neg");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
NodeDef node_def_neg;
ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
EXPECT_EQ(string("add"), node_def_neg.input(0));
// update edge of neg
TF_UpdateEdge(graph, TF_Output{one, 0}, TF_Input{neg, 0}, s);
ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
EXPECT_EQ(string("one:0"), node_def_neg.input(0));
// Clean up
TF_DeleteGraph(graph);
TF_DeleteStatus(s);
}
/* /*
TODO(skyewm): this test currently DCHECKs, change to bad status TODO(skyewm): this test currently DCHECKs, change to bad status

View File

@ -57,43 +57,7 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
TF_Status* status) { TF_Status* status) {
mutex_lock l(graph->mu); TF_UpdateEdge(graph, new_src, dst, status);
tensorflow::shape_inference::InferenceContext* ic =
graph->refiner.GetContext(&new_src.oper->node);
if (ic->num_outputs() <= new_src.index) {
status->status = tensorflow::errors::OutOfRange(
"Cannot update edge. Output index [", new_src.index,
"] is greater than the number of total outputs [", ic->num_outputs(),
"].");
return;
}
tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
tensorflow::shape_inference::InferenceContext* ic_dst =
graph->refiner.GetContext(&dst.oper->node);
if (ic_dst->num_inputs() <= dst.index) {
status->status = tensorflow::errors::OutOfRange(
"Cannot update edge. Input index [", dst.index,
"] is greater than the number of total inputs [", ic_dst->num_inputs(),
"].");
return;
}
if (!ic_dst->MergeInput(dst.index, shape)) {
status->status = tensorflow::errors::InvalidArgument(
"Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
" and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
return;
}
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
&dst.oper->node, dst.index);
if (TF_GetCode(status) == TF_OK) {
// This modification only updates the destination node for
// the purposes of running this graph in a session. Thus, we don't
// record the source node as being modified.
RecordMutation(graph, *dst.oper, "updating input tensor");
}
} }
void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) { void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {