From c65b9f87d91f51a233cb649f4d1a5b5f63a4d5e1 Mon Sep 17 00:00:00 2001
From: Olivia Nordquist <nolivia@google.com>
Date: Tue, 26 Sep 2017 19:56:26 -0700
Subject: [PATCH] implementing _update_input for the C API

PiperOrigin-RevId: 170147211
---
 tensorflow/c/c_api_function_test.cc     |  4 +-
 tensorflow/c/python_api.cc              |  7 +++
 tensorflow/c/python_api.h               |  3 ++
 tensorflow/cc/ops/while_loop_test.cc    |  4 +-
 tensorflow/core/graph/graph.cc          | 45 +++++++++++++----
 tensorflow/core/graph/graph.h           |  9 ++++
 tensorflow/core/graph/graph_test.cc     | 37 ++++++++++++++
 tensorflow/python/framework/ops.py      | 35 ++++++++------
 tensorflow/python/framework/ops_test.py | 64 +++++++++++++++++++++++++
 9 files changed, 180 insertions(+), 28 deletions(-)

diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc
index 4ccff317516..a5a66d93853 100644
--- a/tensorflow/c/c_api_function_test.cc
+++ b/tensorflow/c/c_api_function_test.cc
@@ -1097,7 +1097,7 @@ TEST_F(CApiFunctionTest, InvalidInputTensor_HighIndex) {
   TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
   TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
   DefineT(-1, {}, {{feed1, 0}, {feed2, 2}}, {{add, 0}}, {}, true);
-  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
+  EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s_));
   EXPECT_EQ(string("Node 'feed2' (type: 'Placeholder', num of outputs: 1) does "
                    "not have output 2\n\tEncountered while processing "
                    "input 1 into function 'MyFunc'"),
@@ -1134,7 +1134,7 @@ TEST_F(CApiFunctionTest, InvalidOutputTensor_HighIndex) {
   TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
   TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
   DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {{add, 3}}, {}, true);
-  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
+  EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s_));
   EXPECT_EQ(string("Node 'add' (type: 'AddN', num of outputs: 1) does "
                    "not have output 3\n\tEncountered while processing "
                    "output 0 from function 'MyFunc'"),
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index b8d36b89472..0fe85d5d2c6 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -29,4 +29,11 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
   op->node.set_requested_device(device);
 }
 
+void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
+                TF_Status* status) {
+  mutex_lock l(graph->mu);
+  status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
+                                           &dst.oper->node, dst.index);
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h
index e1a55d7755a..ab71a4170bb 100644
--- a/tensorflow/c/python_api.h
+++ b/tensorflow/c/python_api.h
@@ -27,6 +27,9 @@ void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input);
 
 void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device);
 
+void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
+                TF_Status* status);
+
 }  // namespace tensorflow
 
 #endif  // THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_
diff --git a/tensorflow/cc/ops/while_loop_test.cc b/tensorflow/cc/ops/while_loop_test.cc
index e3f6523c190..18b8be3794f 100644
--- a/tensorflow/cc/ops/while_loop_test.cc
+++ b/tensorflow/cc/ops/while_loop_test.cc
@@ -146,7 +146,7 @@ TEST_F(WhileLoopTest, InvalidCondOutputIndex) {
         *output = {less.node(), 100};
         return s.status();
       },
-      AddOneBody, error::INVALID_ARGUMENT,
+      AddOneBody, error::OUT_OF_RANGE,
       "Node 'cond/Less' (type: 'Less', num of outputs: 1) does not have output "
       "100");
 }
@@ -182,7 +182,7 @@ TEST_F(WhileLoopTest, InvalidBodyOutputIndex) {
                outputs->emplace_back(add.node(), 100);
                return s.status();
              },
-             error::INVALID_ARGUMENT,
+             error::OUT_OF_RANGE,
              "Node 'body/Add' (type: 'Add', num of outputs: 1) does not have "
              "output 100");
 }
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 45ab38c3959..2ad0081e1fb 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -261,7 +261,6 @@ Status Node::input_node(int idx, const Node** const_n) const {
   return Status::OK();
 }
 
-
 // Graph
 
 Graph::Graph(const OpRegistryInterface* ops)
@@ -420,6 +419,34 @@ void Graph::RemoveEdge(const Edge* e) {
   --num_edges_;
 }
 
+Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
+                         int dst_index) {
+  TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
+  TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
+  const Edge* e = FindEdge(dst, dst_index);
+  if (e == nullptr) {
+    return errors::InvalidArgument("Couldn't find edge to ",
+                                   dst->DebugString());
+  }
+  RemoveEdge(e);
+  AddEdge(new_src, new_src_index, dst, dst_index);
+  dst->MaybeCopyOnWrite();
+  (*dst->props_->node_def.mutable_input())[dst_index] =
+      strings::StrCat(new_src->name(), ":", new_src_index);
+  return Status::OK();
+}
+
+const Edge* Graph::FindEdge(const Node* dst, int index) {
+  for (const Edge* e : edges_) {
+    // edges_ will contain null edges if RemoveEdge() was called.
+    if (e == nullptr) continue;
+    if (e->dst() == dst && e->dst_input() == index) {
+      return e;
+    }
+  }
+  return nullptr;
+}
+
 Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
   return ops_.AddLibrary(fdef_lib);
 }
@@ -528,10 +555,10 @@ Status Graph::IsValidNode(const Node* node) const {
 Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
   TF_RETURN_IF_ERROR(IsValidNode(node));
   if (idx >= node->num_outputs()) {
-    return errors::InvalidArgument("Node '", node->name(), "' (type: '",
-                                   node->op_def().name(),
-                                   "', num of outputs: ", node->num_outputs(),
-                                   ") does not have ", "output ", idx);
+    return errors::OutOfRange("Node '", node->name(), "' (type: '",
+                              node->op_def().name(),
+                              "', num of outputs: ", node->num_outputs(),
+                              ") does not have ", "output ", idx);
   }
   return Status::OK();
 }
@@ -539,10 +566,10 @@ Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
 Status Graph::IsValidInputTensor(const Node* node, int idx) const {
   TF_RETURN_IF_ERROR(IsValidNode(node));
   if (idx >= node->num_inputs()) {
-    return errors::InvalidArgument("Node '", node->name(), "' (type: '",
-                                   node->op_def().name(),
-                                   "', num of inputs: ", node->num_inputs(),
-                                   ") does not have ", "input ", idx);
+    return errors::OutOfRange("Node '", node->name(), "' (type: '",
+                              node->op_def().name(),
+                              "', num of inputs: ", node->num_inputs(),
+                              ") does not have ", "input ", idx);
   }
   return Status::OK();
 }
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 72c8d38cb91..5a31a6216b3 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -443,6 +443,11 @@ class Graph {
   // REQUIRES: The edge must exist.
   void RemoveEdge(const Edge* edge);
 
+  // Updates the input to a node.  The existing edge to `dst` is removed
+  // and an edge from `new_src` to `dst` is created. The NodeDef associated with
+  // `dst` is also updated.
+  Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index);
+
   // Adds the function and gradient definitions in `fdef_lib` to this graph's op
   // registry. Ignores duplicate functions, and returns a bad status if an
   // imported function differs from an existing function or op with the same
@@ -631,6 +636,10 @@ class Graph {
   // AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
   std::map<string, WhileContext> while_ctxs_;
 
+  // Searches through edges_ for the Edge whose destination node and index
+  // matches dst. An edge with destination `dst` must exist in the graph.
+  const Edge* FindEdge(const Node* dst, int index);
+
   TF_DISALLOW_COPY_AND_ASSIGN(Graph);
 };
 
diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc
index ca77f3b44d4..85eba0e1662 100644
--- a/tensorflow/core/graph/graph_test.cc
+++ b/tensorflow/core/graph/graph_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include <set>
 #include <vector>
+#include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/function_testlib.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/node_builder.h"
@@ -410,6 +411,42 @@ TEST_F(GraphTest, IsValidNode) {
             s.error_message());
 }
 
+TEST_F(GraphTest, UpdateEdge) {
+  // Build a little graph
+  Node* a = FromNodeDef("A", "OneOutput", 0);
+  Node* b = FromNodeDef("B", "OneInputTwoOutputs", 1);
+  Node* c = FromNodeDef("C", "OneInputTwoOutputs", 1);
+  Node* d = FromNodeDef("D", "OneInput", 1);
+
+  graph_.AddControlEdge(graph_.source_node(), a);
+  graph_.AddControlEdge(a, graph_.sink_node());
+  graph_.AddEdge(a, 0, c, 0);
+
+  graph_.AddControlEdge(c, graph_.sink_node());
+  graph_.AddEdge(c, 0, b, 0);
+  graph_.AddEdge(c, 1, d, 0);
+
+  // Initial edge connections
+  EXPECT_EQ("0->1;0->2;2->1;2->4;4->1;4->3;4->5;", EdgeIter(graph_));
+
+  // Update the inputs, expect that Edge a to b (2->3) is now in the graph
+  // and c to b (4->3) no longer appears.
+  TF_EXPECT_OK(graph_.UpdateEdge(a, 0, b, 0));
+  // Check that the edge is connecting the correct nodes.
+  EXPECT_EQ("0->1;0->2;2->1;2->3;2->4;4->1;4->5;", EdgeIter(graph_));
+
+  // Update a's 0th output again.
+  TF_EXPECT_OK(graph_.UpdateEdge(a, 0, d, 0));
+  EXPECT_EQ("0->1;0->2;2->1;2->3;2->4;2->5;4->1;", EdgeIter(graph_));
+
+  // Update a's 1st output which is out of range.
+  Status s = graph_.UpdateEdge(a, 1, d, 0);
+  EXPECT_FALSE(s.ok());
+  EXPECT_EQ(
+      s.error_message(),
+      "Node 'A' (type: 'OneOutput', num of outputs: 1) does not have output 1");
+}
+
 TEST_F(GraphTest, InputEdges) {
   Node* a = FromNodeDef("A", "OneOutput", 0);
   Node* b = FromNodeDef("B", "TwoInputsOneOutput", 2);
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index db9aa1e0617..d6615563aca 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1920,25 +1920,30 @@ class Operation(object):
         or if input tensor type is not convertible to dtype.
       ValueError: if the Tensor is from a different graph.
     """
-    assert not self._graph._c_graph, (  # pylint: disable=protected-access
-        "Operation._update_input doesn't work with C API")
     if not isinstance(tensor, Tensor):
       raise TypeError("tensor must be a Tensor: %s" % tensor)
     _assert_same_graph(self, tensor)
-    if dtype is None:
-      dtype = tensor.dtype
+    if _USE_C_API:
+      with errors.raise_exception_on_not_ok_status() as status:
+        c_api.UpdateEdge(
+            self._graph._c_graph,  # pylint: disable=protected-access
+            tensor._as_tf_output(),  # pylint: disable=protected-access
+            self._tf_input(index),
+            status)
     else:
-      dtype = dtypes.as_dtype(dtype)
-      if not dtype.is_compatible_with(tensor.dtype):
-        raise TypeError(
-            "Cannot convert a tensor of type %s to an input of type %s" %
-            (tensor.dtype.name, dtype.name))
-
-    self._inputs[index].consumers().remove(self)
-    self._inputs[index] = tensor
-    self._input_types_val[index] = dtype
-    tensor._add_consumer(self)  # pylint: disable=protected-access
-    self._recompute_node_def()
+      if dtype is None:
+        dtype = tensor.dtype
+      else:
+        dtype = dtypes.as_dtype(dtype)
+        if not dtype.is_compatible_with(tensor.dtype):
+          raise TypeError(
+              "Cannot convert a tensor of type %s to an input of type %s" %
+              (tensor.dtype.name, dtype.name))
+      self._inputs[index].consumers().remove(self)
+      self._inputs[index] = tensor
+      self._input_types_val[index] = dtype
+      tensor._add_consumer(self)  # pylint: disable=protected-access
+      self._recompute_node_def()
 
   def _add_control_inputs(self, ops):
     """Add a list of new control inputs to this operation.
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 00a0d1635d4..caf24617292 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -424,6 +424,70 @@ class OperationTest(test_util.TensorFlowTestCase):
           "Graph is invalid, contains a cycle with 2 nodes"):
         sess.run(x)
 
+  @test_util.enable_c_api
+  def testUpdateInput(self):
+    g = ops.Graph()
+    with g.as_default():
+      x = constant_op.constant(1)
+      y = constant_op.constant(2)
+      z = x + y
+      z.op._update_input(0, y)  # pylint: disable=protected-access
+    with session.Session(graph=g) as sess:
+      self.assertEquals(sess.run(z), 4)
+    z.op._update_input(0, x)
+    with session.Session(graph=g) as sess:
+      self.assertEquals(sess.run(z), 3)
+    z.op._update_input(1, y)
+    with session.Session(graph=g) as sess:
+      self.assertEquals(sess.run(z), 3)
+
+  @test_util.enable_c_api
+  def testUpdateInputGraphError(self):
+    g_0 = ops.Graph()
+    g_1 = ops.Graph()
+    with g_0.as_default():
+      x = constant_op.constant(1)
+    with g_1.as_default():
+      y = constant_op.constant(2)
+      z = y * 2
+      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
+        z.op._update_input(0, x)  # pylint: disable=protected-access
+
+  # TODO(nolivia): check the shape/type in _update_input() instead of depending
+  # on run to do that.
+  @test_util.enable_c_api
+  def testUpdateInputTypeError(self):
+    g = ops.Graph()
+    with g.as_default():
+      w = constant_op.constant(0)
+      x = constant_op.constant("")
+      y = constant_op.constant(1)
+      z = y + w
+      z.op._update_input(0, x)  # pylint: disable=protected-access
+    with session.Session(graph=g) as sess:
+      with self.assertRaisesRegexp(
+          errors.InvalidArgumentError,
+          "Input 0 of node add was passed string from Const_1:0 incompatible "
+          "with expected int32"):
+        sess.run(z)
+
+  # C-API throws the error differently.
+  def testUpdateInputOutOfRange(self):
+    g = ops.Graph()
+    with g.as_default():
+      x = constant_op.constant(1)
+    with self.assertRaises(IndexError):
+      x.op._update_input(1, x)  # pylint: disable=protected-access
+
+  @test_util.enable_c_api
+  def testUpdateInputOutOfRangeC(self):
+    g = ops.Graph()
+    with g.as_default():
+      x = constant_op.constant(1)
+    with self.assertRaisesRegexp(errors.OutOfRangeError,
+                                 "does not have input 1"):
+      x.op._update_input(1, x)  # pylint: disable=protected-access
+
 
 class CreateOpTest(test_util.TensorFlowTestCase):