implementing _update_input for the C API

PiperOrigin-RevId: 170147211
This commit is contained in:
Olivia Nordquist 2017-09-26 19:56:26 -07:00 committed by TensorFlower Gardener
parent 035a9be3cc
commit c65b9f87d9
9 changed files with 180 additions and 28 deletions

View File

@ -1097,7 +1097,7 @@ TEST_F(CApiFunctionTest, InvalidInputTensor_HighIndex) {
TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
DefineT(-1, {}, {{feed1, 0}, {feed2, 2}}, {{add, 0}}, {}, true);
EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s_));
EXPECT_EQ(string("Node 'feed2' (type: 'Placeholder', num of outputs: 1) does "
"not have output 2\n\tEncountered while processing "
"input 1 into function 'MyFunc'"),
@ -1134,7 +1134,7 @@ TEST_F(CApiFunctionTest, InvalidOutputTensor_HighIndex) {
TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {{add, 3}}, {}, true);
EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s_));
EXPECT_EQ(string("Node 'add' (type: 'AddN', num of outputs: 1) does "
"not have output 3\n\tEncountered while processing "
"output 0 from function 'MyFunc'"),

View File

@ -29,4 +29,11 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
op->node.set_requested_device(device);
}
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
TF_Status* status) {
mutex_lock l(graph->mu);
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
&dst.oper->node, dst.index);
}
} // namespace tensorflow

View File

@ -27,6 +27,9 @@ void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input);
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device);
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
TF_Status* status);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_

View File

@ -146,7 +146,7 @@ TEST_F(WhileLoopTest, InvalidCondOutputIndex) {
*output = {less.node(), 100};
return s.status();
},
AddOneBody, error::INVALID_ARGUMENT,
AddOneBody, error::OUT_OF_RANGE,
"Node 'cond/Less' (type: 'Less', num of outputs: 1) does not have output "
"100");
}
@ -182,7 +182,7 @@ TEST_F(WhileLoopTest, InvalidBodyOutputIndex) {
outputs->emplace_back(add.node(), 100);
return s.status();
},
error::INVALID_ARGUMENT,
error::OUT_OF_RANGE,
"Node 'body/Add' (type: 'Add', num of outputs: 1) does not have "
"output 100");
}

View File

@ -261,7 +261,6 @@ Status Node::input_node(int idx, const Node** const_n) const {
return Status::OK();
}
// Graph
Graph::Graph(const OpRegistryInterface* ops)
@ -420,6 +419,34 @@ void Graph::RemoveEdge(const Edge* e) {
--num_edges_;
}
Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
int dst_index) {
TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
const Edge* e = FindEdge(dst, dst_index);
if (e == nullptr) {
return errors::InvalidArgument("Couldn't find edge to ",
dst->DebugString());
}
RemoveEdge(e);
AddEdge(new_src, new_src_index, dst, dst_index);
dst->MaybeCopyOnWrite();
(*dst->props_->node_def.mutable_input())[dst_index] =
strings::StrCat(new_src->name(), ":", new_src_index);
return Status::OK();
}
const Edge* Graph::FindEdge(const Node* dst, int index) {
for (const Edge* e : edges_) {
// edges_ will contain null edges if RemoveEdge() was called.
if (e == nullptr) continue;
if (e->dst() == dst && e->dst_input() == index) {
return e;
}
}
return nullptr;
}
Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
return ops_.AddLibrary(fdef_lib);
}
@ -528,10 +555,10 @@ Status Graph::IsValidNode(const Node* node) const {
Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
TF_RETURN_IF_ERROR(IsValidNode(node));
if (idx >= node->num_outputs()) {
return errors::InvalidArgument("Node '", node->name(), "' (type: '",
node->op_def().name(),
"', num of outputs: ", node->num_outputs(),
") does not have ", "output ", idx);
return errors::OutOfRange("Node '", node->name(), "' (type: '",
node->op_def().name(),
"', num of outputs: ", node->num_outputs(),
") does not have ", "output ", idx);
}
return Status::OK();
}
@ -539,10 +566,10 @@ Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
Status Graph::IsValidInputTensor(const Node* node, int idx) const {
TF_RETURN_IF_ERROR(IsValidNode(node));
if (idx >= node->num_inputs()) {
return errors::InvalidArgument("Node '", node->name(), "' (type: '",
node->op_def().name(),
"', num of inputs: ", node->num_inputs(),
") does not have ", "input ", idx);
return errors::OutOfRange("Node '", node->name(), "' (type: '",
node->op_def().name(),
"', num of inputs: ", node->num_inputs(),
") does not have ", "input ", idx);
}
return Status::OK();
}

View File

@ -443,6 +443,11 @@ class Graph {
// REQUIRES: The edge must exist.
void RemoveEdge(const Edge* edge);
// Updates the input to a node. The existing edge to `dst` is removed
// and an edge from `new_src` to `dst` is created. The NodeDef associated with
// `dst` is also updated.
Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index);
// Adds the function and gradient definitions in `fdef_lib` to this graph's op
// registry. Ignores duplicate functions, and returns a bad status if an
// imported function differs from an existing function or op with the same
@ -631,6 +636,10 @@ class Graph {
// AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
std::map<string, WhileContext> while_ctxs_;
// Searches through edges_ for the Edge whose destination node and index
// matches dst. An edge with destination `dst` must exist in the graph.
const Edge* FindEdge(const Node* dst, int index);
TF_DISALLOW_COPY_AND_ASSIGN(Graph);
};

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <set>
#include <vector>
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
@ -410,6 +411,42 @@ TEST_F(GraphTest, IsValidNode) {
s.error_message());
}
TEST_F(GraphTest, UpdateEdge) {
// Build a little graph
Node* a = FromNodeDef("A", "OneOutput", 0);
Node* b = FromNodeDef("B", "OneInputTwoOutputs", 1);
Node* c = FromNodeDef("C", "OneInputTwoOutputs", 1);
Node* d = FromNodeDef("D", "OneInput", 1);
graph_.AddControlEdge(graph_.source_node(), a);
graph_.AddControlEdge(a, graph_.sink_node());
graph_.AddEdge(a, 0, c, 0);
graph_.AddControlEdge(c, graph_.sink_node());
graph_.AddEdge(c, 0, b, 0);
graph_.AddEdge(c, 1, d, 0);
// Initial edge connections
EXPECT_EQ("0->1;0->2;2->1;2->4;4->1;4->3;4->5;", EdgeIter(graph_));
// Update the inputs, expect that Edge a to b (2->3) is now in the graph
// and c to b (4->3) no longer appears.
TF_EXPECT_OK(graph_.UpdateEdge(a, 0, b, 0));
// Check that the edge is connecting the correct nodes.
EXPECT_EQ("0->1;0->2;2->1;2->3;2->4;4->1;4->5;", EdgeIter(graph_));
// Update a's 0th output again.
TF_EXPECT_OK(graph_.UpdateEdge(a, 0, d, 0));
EXPECT_EQ("0->1;0->2;2->1;2->3;2->4;2->5;4->1;", EdgeIter(graph_));
// Update a's 1st output which is out of range.
Status s = graph_.UpdateEdge(a, 1, d, 0);
EXPECT_FALSE(s.ok());
EXPECT_EQ(
s.error_message(),
"Node 'A' (type: 'OneOutput', num of outputs: 1) does not have output 1");
}
TEST_F(GraphTest, InputEdges) {
Node* a = FromNodeDef("A", "OneOutput", 0);
Node* b = FromNodeDef("B", "TwoInputsOneOutput", 2);

View File

@ -1920,25 +1920,30 @@ class Operation(object):
or if input tensor type is not convertible to dtype.
ValueError: if the Tensor is from a different graph.
"""
assert not self._graph._c_graph, ( # pylint: disable=protected-access
"Operation._update_input doesn't work with C API")
if not isinstance(tensor, Tensor):
raise TypeError("tensor must be a Tensor: %s" % tensor)
_assert_same_graph(self, tensor)
if dtype is None:
dtype = tensor.dtype
if _USE_C_API:
with errors.raise_exception_on_not_ok_status() as status:
c_api.UpdateEdge(
self._graph._c_graph, # pylint: disable=protected-access
tensor._as_tf_output(), # pylint: disable=protected-access
self._tf_input(index),
status)
else:
dtype = dtypes.as_dtype(dtype)
if not dtype.is_compatible_with(tensor.dtype):
raise TypeError(
"Cannot convert a tensor of type %s to an input of type %s" %
(tensor.dtype.name, dtype.name))
self._inputs[index].consumers().remove(self)
self._inputs[index] = tensor
self._input_types_val[index] = dtype
tensor._add_consumer(self) # pylint: disable=protected-access
self._recompute_node_def()
if dtype is None:
dtype = tensor.dtype
else:
dtype = dtypes.as_dtype(dtype)
if not dtype.is_compatible_with(tensor.dtype):
raise TypeError(
"Cannot convert a tensor of type %s to an input of type %s" %
(tensor.dtype.name, dtype.name))
self._inputs[index].consumers().remove(self)
self._inputs[index] = tensor
self._input_types_val[index] = dtype
tensor._add_consumer(self) # pylint: disable=protected-access
self._recompute_node_def()
def _add_control_inputs(self, ops):
"""Add a list of new control inputs to this operation.

View File

@ -424,6 +424,70 @@ class OperationTest(test_util.TensorFlowTestCase):
"Graph is invalid, contains a cycle with 2 nodes"):
sess.run(x)
@test_util.enable_c_api
def testUpdateInput(self):
g = ops.Graph()
with g.as_default():
x = constant_op.constant(1)
y = constant_op.constant(2)
z = x + y
z.op._update_input(0, y) # pylint: disable=protected-access
with session.Session(graph=g) as sess:
self.assertEquals(sess.run(z), 4)
z.op._update_input(0, x)
with session.Session(graph=g) as sess:
self.assertEquals(sess.run(z), 3)
z.op._update_input(1, y)
with session.Session(graph=g) as sess:
self.assertEquals(sess.run(z), 3)
@test_util.enable_c_api
def testUpdateInputGraphError(self):
g_0 = ops.Graph()
g_1 = ops.Graph()
with g_0.as_default():
x = constant_op.constant(1)
with g_1.as_default():
y = constant_op.constant(2)
z = y * 2
with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
z.op._update_input(0, x) # pylint: disable=protected-access
# TODO(nolivia): check the shape/type in _update_input() instead of depending
# on run to do that.
@test_util.enable_c_api
def testUpdateInputTypeError(self):
g = ops.Graph()
with g.as_default():
w = constant_op.constant(0)
x = constant_op.constant("")
y = constant_op.constant(1)
z = y + w
z.op._update_input(0, x) # pylint: disable=protected-access
with session.Session(graph=g) as sess:
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Input 0 of node add was passed string from Const_1:0 incompatible "
"with expected int32"):
sess.run(z)
# C-API throws the error differently.
def testUpdateInputOutOfRange(self):
g = ops.Graph()
with g.as_default():
x = constant_op.constant(1)
with self.assertRaises(IndexError):
x.op._update_input(1, x) # pylint: disable=protected-access
@test_util.enable_c_api
def testUpdateInputOutOfRangeC(self):
g = ops.Graph()
with g.as_default():
x = constant_op.constant(1)
with self.assertRaisesRegexp(errors.OutOfRangeError,
"does not have input 1"):
x.op._update_input(1, x) # pylint: disable=protected-access
class CreateOpTest(test_util.TensorFlowTestCase):