diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index 4ccff317516..a5a66d93853 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -1097,7 +1097,7 @@ TEST_F(CApiFunctionTest, InvalidInputTensor_HighIndex) { TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); TF_Operation* add = Add(feed1, feed2, func_graph_, s_); DefineT(-1, {}, {{feed1, 0}, {feed2, 2}}, {{add, 0}}, {}, true); - EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s_)); EXPECT_EQ(string("Node 'feed2' (type: 'Placeholder', num of outputs: 1) does " "not have output 2\n\tEncountered while processing " "input 1 into function 'MyFunc'"), @@ -1134,7 +1134,7 @@ TEST_F(CApiFunctionTest, InvalidOutputTensor_HighIndex) { TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); TF_Operation* add = Add(feed1, feed2, func_graph_, s_); DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {{add, 3}}, {}, true); - EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s_)); EXPECT_EQ(string("Node 'add' (type: 'AddN', num of outputs: 1) does " "not have output 3\n\tEncountered while processing " "output 0 from function 'MyFunc'"), diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index b8d36b89472..0fe85d5d2c6 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -29,4 +29,11 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { op->node.set_requested_device(device); } +void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, + TF_Status* status) { + mutex_lock l(graph->mu); + status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index, + &dst.oper->node, dst.index); +} + } // namespace tensorflow diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index e1a55d7755a..ab71a4170bb 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -27,6 +27,9 @@ void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input); void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); +void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, + TF_Status* status); + } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/cc/ops/while_loop_test.cc b/tensorflow/cc/ops/while_loop_test.cc index e3f6523c190..18b8be3794f 100644 --- a/tensorflow/cc/ops/while_loop_test.cc +++ b/tensorflow/cc/ops/while_loop_test.cc @@ -146,7 +146,7 @@ TEST_F(WhileLoopTest, InvalidCondOutputIndex) { *output = {less.node(), 100}; return s.status(); }, - AddOneBody, error::INVALID_ARGUMENT, + AddOneBody, error::OUT_OF_RANGE, "Node 'cond/Less' (type: 'Less', num of outputs: 1) does not have output " "100"); } @@ -182,7 +182,7 @@ TEST_F(WhileLoopTest, InvalidBodyOutputIndex) { outputs->emplace_back(add.node(), 100); return s.status(); }, - error::INVALID_ARGUMENT, + error::OUT_OF_RANGE, "Node 'body/Add' (type: 'Add', num of outputs: 1) does not have " "output 100"); } diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 45ab38c3959..2ad0081e1fb 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -261,7 +261,6 @@ Status Node::input_node(int idx, const Node** const_n) const { return Status::OK(); } - // Graph Graph::Graph(const OpRegistryInterface* ops) @@ -420,6 +419,34 @@ void Graph::RemoveEdge(const Edge* e) { --num_edges_; } +Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst, + int dst_index) { + TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index)); + TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index)); + const Edge* e = FindEdge(dst, dst_index); + if (e == nullptr) { + return errors::InvalidArgument("Couldn't find edge to ", + dst->DebugString()); + } + RemoveEdge(e); + AddEdge(new_src, new_src_index, dst, dst_index); + dst->MaybeCopyOnWrite(); + (*dst->props_->node_def.mutable_input())[dst_index] = + strings::StrCat(new_src->name(), ":", new_src_index); + return Status::OK(); +} + +const Edge* Graph::FindEdge(const Node* dst, int index) { + for (const Edge* e : edges_) { + // edges_ will contain null edges if RemoveEdge() was called. + if (e == nullptr) continue; + if (e->dst() == dst && e->dst_input() == index) { + return e; + } + } + return nullptr; +} + Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) { return ops_.AddLibrary(fdef_lib); } @@ -528,10 +555,10 @@ Status Graph::IsValidNode(const Node* node) const { Status Graph::IsValidOutputTensor(const Node* node, int idx) const { TF_RETURN_IF_ERROR(IsValidNode(node)); if (idx >= node->num_outputs()) { - return errors::InvalidArgument("Node '", node->name(), "' (type: '", - node->op_def().name(), - "', num of outputs: ", node->num_outputs(), - ") does not have ", "output ", idx); + return errors::OutOfRange("Node '", node->name(), "' (type: '", + node->op_def().name(), + "', num of outputs: ", node->num_outputs(), + ") does not have ", "output ", idx); } return Status::OK(); } @@ -539,10 +566,10 @@ Status Graph::IsValidOutputTensor(const Node* node, int idx) const { Status Graph::IsValidInputTensor(const Node* node, int idx) const { TF_RETURN_IF_ERROR(IsValidNode(node)); if (idx >= node->num_inputs()) { - return errors::InvalidArgument("Node '", node->name(), "' (type: '", - node->op_def().name(), - "', num of inputs: ", node->num_inputs(), - ") does not have ", "input ", idx); + return errors::OutOfRange("Node '", node->name(), "' (type: '", + node->op_def().name(), + "', num of inputs: ", node->num_inputs(), + ") does not have ", "input ", idx); } return Status::OK(); } diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 72c8d38cb91..5a31a6216b3 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -443,6 +443,11 @@ class Graph { // REQUIRES: The edge must exist. void RemoveEdge(const Edge* edge); + // Updates the input to a node. The existing edge to `dst` is removed + // and an edge from `new_src` to `dst` is created. The NodeDef associated with + // `dst` is also updated. + Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index); + // Adds the function and gradient definitions in `fdef_lib` to this graph's op // registry. Ignores duplicate functions, and returns a bad status if an // imported function differs from an existing function or op with the same @@ -631,6 +636,10 @@ class Graph { // AddWhileContext() or Node::while_ctx(), but this manages the lifetime. std::map while_ctxs_; + // Searches through edges_ for the Edge whose destination node and index + // matches dst. An edge with destination `dst` must exist in the graph. + const Edge* FindEdge(const Node* dst, int index); + TF_DISALLOW_COPY_AND_ASSIGN(Graph); }; diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc index ca77f3b44d4..85eba0e1662 100644 --- a/tensorflow/core/graph/graph_test.cc +++ b/tensorflow/core/graph/graph_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" @@ -410,6 +411,42 @@ TEST_F(GraphTest, IsValidNode) { s.error_message()); } +TEST_F(GraphTest, UpdateEdge) { + // Build a little graph + Node* a = FromNodeDef("A", "OneOutput", 0); + Node* b = FromNodeDef("B", "OneInputTwoOutputs", 1); + Node* c = FromNodeDef("C", "OneInputTwoOutputs", 1); + Node* d = FromNodeDef("D", "OneInput", 1); + + graph_.AddControlEdge(graph_.source_node(), a); + graph_.AddControlEdge(a, graph_.sink_node()); + graph_.AddEdge(a, 0, c, 0); + + graph_.AddControlEdge(c, graph_.sink_node()); + graph_.AddEdge(c, 0, b, 0); + graph_.AddEdge(c, 1, d, 0); + + // Initial edge connections + EXPECT_EQ("0->1;0->2;2->1;2->4;4->1;4->3;4->5;", EdgeIter(graph_)); + + // Update the inputs, expect that Edge a to b (2->3) is now in the graph + // and c to b (4->3) no longer appears. + TF_EXPECT_OK(graph_.UpdateEdge(a, 0, b, 0)); + // Check that the edge is connecting the correct nodes. + EXPECT_EQ("0->1;0->2;2->1;2->3;2->4;4->1;4->5;", EdgeIter(graph_)); + + // Update a's 0th output again. + TF_EXPECT_OK(graph_.UpdateEdge(a, 0, d, 0)); + EXPECT_EQ("0->1;0->2;2->1;2->3;2->4;2->5;4->1;", EdgeIter(graph_)); + + // Update a's 1st output which is out of range. + Status s = graph_.UpdateEdge(a, 1, d, 0); + EXPECT_FALSE(s.ok()); + EXPECT_EQ( + s.error_message(), + "Node 'A' (type: 'OneOutput', num of outputs: 1) does not have output 1"); +} + TEST_F(GraphTest, InputEdges) { Node* a = FromNodeDef("A", "OneOutput", 0); Node* b = FromNodeDef("B", "TwoInputsOneOutput", 2); diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index db9aa1e0617..d6615563aca 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -1920,25 +1920,30 @@ class Operation(object): or if input tensor type is not convertible to dtype. ValueError: if the Tensor is from a different graph. """ - assert not self._graph._c_graph, ( # pylint: disable=protected-access - "Operation._update_input doesn't work with C API") if not isinstance(tensor, Tensor): raise TypeError("tensor must be a Tensor: %s" % tensor) _assert_same_graph(self, tensor) - if dtype is None: - dtype = tensor.dtype + if _USE_C_API: + with errors.raise_exception_on_not_ok_status() as status: + c_api.UpdateEdge( + self._graph._c_graph, # pylint: disable=protected-access + tensor._as_tf_output(), # pylint: disable=protected-access + self._tf_input(index), + status) else: - dtype = dtypes.as_dtype(dtype) - if not dtype.is_compatible_with(tensor.dtype): - raise TypeError( - "Cannot convert a tensor of type %s to an input of type %s" % - (tensor.dtype.name, dtype.name)) - - self._inputs[index].consumers().remove(self) - self._inputs[index] = tensor - self._input_types_val[index] = dtype - tensor._add_consumer(self) # pylint: disable=protected-access - self._recompute_node_def() + if dtype is None: + dtype = tensor.dtype + else: + dtype = dtypes.as_dtype(dtype) + if not dtype.is_compatible_with(tensor.dtype): + raise TypeError( + "Cannot convert a tensor of type %s to an input of type %s" % + (tensor.dtype.name, dtype.name)) + self._inputs[index].consumers().remove(self) + self._inputs[index] = tensor + self._input_types_val[index] = dtype + tensor._add_consumer(self) # pylint: disable=protected-access + self._recompute_node_def() def _add_control_inputs(self, ops): """Add a list of new control inputs to this operation. diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 00a0d1635d4..caf24617292 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -424,6 +424,70 @@ class OperationTest(test_util.TensorFlowTestCase): "Graph is invalid, contains a cycle with 2 nodes"): sess.run(x) + @test_util.enable_c_api + def testUpdateInput(self): + g = ops.Graph() + with g.as_default(): + x = constant_op.constant(1) + y = constant_op.constant(2) + z = x + y + z.op._update_input(0, y) # pylint: disable=protected-access + with session.Session(graph=g) as sess: + self.assertEquals(sess.run(z), 4) + z.op._update_input(0, x) + with session.Session(graph=g) as sess: + self.assertEquals(sess.run(z), 3) + z.op._update_input(1, y) + with session.Session(graph=g) as sess: + self.assertEquals(sess.run(z), 3) + + @test_util.enable_c_api + def testUpdateInputGraphError(self): + g_0 = ops.Graph() + g_1 = ops.Graph() + with g_0.as_default(): + x = constant_op.constant(1) + with g_1.as_default(): + y = constant_op.constant(2) + z = y * 2 + with self.assertRaisesRegexp(ValueError, "must be from the same graph"): + z.op._update_input(0, x) # pylint: disable=protected-access + + # TODO(nolivia): check the shape/type in _update_input() instead of depending + # on run to do that. + @test_util.enable_c_api + def testUpdateInputTypeError(self): + g = ops.Graph() + with g.as_default(): + w = constant_op.constant(0) + x = constant_op.constant("") + y = constant_op.constant(1) + z = y + w + z.op._update_input(0, x) # pylint: disable=protected-access + with session.Session(graph=g) as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Input 0 of node add was passed string from Const_1:0 incompatible " + "with expected int32"): + sess.run(z) + + # C-API throws the error differently. + def testUpdateInputOutOfRange(self): + g = ops.Graph() + with g.as_default(): + x = constant_op.constant(1) + with self.assertRaises(IndexError): + x.op._update_input(1, x) # pylint: disable=protected-access + + @test_util.enable_c_api + def testUpdateInputOutOfRangeC(self): + g = ops.Graph() + with g.as_default(): + x = constant_op.constant(1) + with self.assertRaisesRegexp(errors.OutOfRangeError, + "does not have input 1"): + x.op._update_input(1, x) # pylint: disable=protected-access + class CreateOpTest(test_util.TensorFlowTestCase):