implementing _update_input for the C API
PiperOrigin-RevId: 170147211
This commit is contained in:
parent
035a9be3cc
commit
c65b9f87d9
@ -1097,7 +1097,7 @@ TEST_F(CApiFunctionTest, InvalidInputTensor_HighIndex) {
|
||||
TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
|
||||
TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
|
||||
DefineT(-1, {}, {{feed1, 0}, {feed2, 2}}, {{add, 0}}, {}, true);
|
||||
EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
|
||||
EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s_));
|
||||
EXPECT_EQ(string("Node 'feed2' (type: 'Placeholder', num of outputs: 1) does "
|
||||
"not have output 2\n\tEncountered while processing "
|
||||
"input 1 into function 'MyFunc'"),
|
||||
@ -1134,7 +1134,7 @@ TEST_F(CApiFunctionTest, InvalidOutputTensor_HighIndex) {
|
||||
TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
|
||||
TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
|
||||
DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {{add, 3}}, {}, true);
|
||||
EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
|
||||
EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s_));
|
||||
EXPECT_EQ(string("Node 'add' (type: 'AddN', num of outputs: 1) does "
|
||||
"not have output 3\n\tEncountered while processing "
|
||||
"output 0 from function 'MyFunc'"),
|
||||
|
@ -29,4 +29,11 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
|
||||
op->node.set_requested_device(device);
|
||||
}
|
||||
|
||||
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
|
||||
TF_Status* status) {
|
||||
mutex_lock l(graph->mu);
|
||||
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
|
||||
&dst.oper->node, dst.index);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -27,6 +27,9 @@ void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input);
|
||||
|
||||
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device);
|
||||
|
||||
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
|
||||
TF_Status* status);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_
|
||||
|
@ -146,7 +146,7 @@ TEST_F(WhileLoopTest, InvalidCondOutputIndex) {
|
||||
*output = {less.node(), 100};
|
||||
return s.status();
|
||||
},
|
||||
AddOneBody, error::INVALID_ARGUMENT,
|
||||
AddOneBody, error::OUT_OF_RANGE,
|
||||
"Node 'cond/Less' (type: 'Less', num of outputs: 1) does not have output "
|
||||
"100");
|
||||
}
|
||||
@ -182,7 +182,7 @@ TEST_F(WhileLoopTest, InvalidBodyOutputIndex) {
|
||||
outputs->emplace_back(add.node(), 100);
|
||||
return s.status();
|
||||
},
|
||||
error::INVALID_ARGUMENT,
|
||||
error::OUT_OF_RANGE,
|
||||
"Node 'body/Add' (type: 'Add', num of outputs: 1) does not have "
|
||||
"output 100");
|
||||
}
|
||||
|
@ -261,7 +261,6 @@ Status Node::input_node(int idx, const Node** const_n) const {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
// Graph
|
||||
|
||||
Graph::Graph(const OpRegistryInterface* ops)
|
||||
@ -420,6 +419,34 @@ void Graph::RemoveEdge(const Edge* e) {
|
||||
--num_edges_;
|
||||
}
|
||||
|
||||
Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
|
||||
int dst_index) {
|
||||
TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
|
||||
TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
|
||||
const Edge* e = FindEdge(dst, dst_index);
|
||||
if (e == nullptr) {
|
||||
return errors::InvalidArgument("Couldn't find edge to ",
|
||||
dst->DebugString());
|
||||
}
|
||||
RemoveEdge(e);
|
||||
AddEdge(new_src, new_src_index, dst, dst_index);
|
||||
dst->MaybeCopyOnWrite();
|
||||
(*dst->props_->node_def.mutable_input())[dst_index] =
|
||||
strings::StrCat(new_src->name(), ":", new_src_index);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const Edge* Graph::FindEdge(const Node* dst, int index) {
|
||||
for (const Edge* e : edges_) {
|
||||
// edges_ will contain null edges if RemoveEdge() was called.
|
||||
if (e == nullptr) continue;
|
||||
if (e->dst() == dst && e->dst_input() == index) {
|
||||
return e;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
|
||||
return ops_.AddLibrary(fdef_lib);
|
||||
}
|
||||
@ -528,10 +555,10 @@ Status Graph::IsValidNode(const Node* node) const {
|
||||
Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
|
||||
TF_RETURN_IF_ERROR(IsValidNode(node));
|
||||
if (idx >= node->num_outputs()) {
|
||||
return errors::InvalidArgument("Node '", node->name(), "' (type: '",
|
||||
node->op_def().name(),
|
||||
"', num of outputs: ", node->num_outputs(),
|
||||
") does not have ", "output ", idx);
|
||||
return errors::OutOfRange("Node '", node->name(), "' (type: '",
|
||||
node->op_def().name(),
|
||||
"', num of outputs: ", node->num_outputs(),
|
||||
") does not have ", "output ", idx);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -539,10 +566,10 @@ Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
|
||||
Status Graph::IsValidInputTensor(const Node* node, int idx) const {
|
||||
TF_RETURN_IF_ERROR(IsValidNode(node));
|
||||
if (idx >= node->num_inputs()) {
|
||||
return errors::InvalidArgument("Node '", node->name(), "' (type: '",
|
||||
node->op_def().name(),
|
||||
"', num of inputs: ", node->num_inputs(),
|
||||
") does not have ", "input ", idx);
|
||||
return errors::OutOfRange("Node '", node->name(), "' (type: '",
|
||||
node->op_def().name(),
|
||||
"', num of inputs: ", node->num_inputs(),
|
||||
") does not have ", "input ", idx);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -443,6 +443,11 @@ class Graph {
|
||||
// REQUIRES: The edge must exist.
|
||||
void RemoveEdge(const Edge* edge);
|
||||
|
||||
// Updates the input to a node. The existing edge to `dst` is removed
|
||||
// and an edge from `new_src` to `dst` is created. The NodeDef associated with
|
||||
// `dst` is also updated.
|
||||
Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index);
|
||||
|
||||
// Adds the function and gradient definitions in `fdef_lib` to this graph's op
|
||||
// registry. Ignores duplicate functions, and returns a bad status if an
|
||||
// imported function differs from an existing function or op with the same
|
||||
@ -631,6 +636,10 @@ class Graph {
|
||||
// AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
|
||||
std::map<string, WhileContext> while_ctxs_;
|
||||
|
||||
// Searches through edges_ for the Edge whose destination node and index
|
||||
// matches dst. An edge with destination `dst` must exist in the graph.
|
||||
const Edge* FindEdge(const Node* dst, int index);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Graph);
|
||||
};
|
||||
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
@ -410,6 +411,42 @@ TEST_F(GraphTest, IsValidNode) {
|
||||
s.error_message());
|
||||
}
|
||||
|
||||
TEST_F(GraphTest, UpdateEdge) {
|
||||
// Build a little graph
|
||||
Node* a = FromNodeDef("A", "OneOutput", 0);
|
||||
Node* b = FromNodeDef("B", "OneInputTwoOutputs", 1);
|
||||
Node* c = FromNodeDef("C", "OneInputTwoOutputs", 1);
|
||||
Node* d = FromNodeDef("D", "OneInput", 1);
|
||||
|
||||
graph_.AddControlEdge(graph_.source_node(), a);
|
||||
graph_.AddControlEdge(a, graph_.sink_node());
|
||||
graph_.AddEdge(a, 0, c, 0);
|
||||
|
||||
graph_.AddControlEdge(c, graph_.sink_node());
|
||||
graph_.AddEdge(c, 0, b, 0);
|
||||
graph_.AddEdge(c, 1, d, 0);
|
||||
|
||||
// Initial edge connections
|
||||
EXPECT_EQ("0->1;0->2;2->1;2->4;4->1;4->3;4->5;", EdgeIter(graph_));
|
||||
|
||||
// Update the inputs, expect that Edge a to b (2->3) is now in the graph
|
||||
// and c to b (4->3) no longer appears.
|
||||
TF_EXPECT_OK(graph_.UpdateEdge(a, 0, b, 0));
|
||||
// Check that the edge is connecting the correct nodes.
|
||||
EXPECT_EQ("0->1;0->2;2->1;2->3;2->4;4->1;4->5;", EdgeIter(graph_));
|
||||
|
||||
// Update a's 0th output again.
|
||||
TF_EXPECT_OK(graph_.UpdateEdge(a, 0, d, 0));
|
||||
EXPECT_EQ("0->1;0->2;2->1;2->3;2->4;2->5;4->1;", EdgeIter(graph_));
|
||||
|
||||
// Update a's 1st output which is out of range.
|
||||
Status s = graph_.UpdateEdge(a, 1, d, 0);
|
||||
EXPECT_FALSE(s.ok());
|
||||
EXPECT_EQ(
|
||||
s.error_message(),
|
||||
"Node 'A' (type: 'OneOutput', num of outputs: 1) does not have output 1");
|
||||
}
|
||||
|
||||
TEST_F(GraphTest, InputEdges) {
|
||||
Node* a = FromNodeDef("A", "OneOutput", 0);
|
||||
Node* b = FromNodeDef("B", "TwoInputsOneOutput", 2);
|
||||
|
@ -1920,25 +1920,30 @@ class Operation(object):
|
||||
or if input tensor type is not convertible to dtype.
|
||||
ValueError: if the Tensor is from a different graph.
|
||||
"""
|
||||
assert not self._graph._c_graph, ( # pylint: disable=protected-access
|
||||
"Operation._update_input doesn't work with C API")
|
||||
if not isinstance(tensor, Tensor):
|
||||
raise TypeError("tensor must be a Tensor: %s" % tensor)
|
||||
_assert_same_graph(self, tensor)
|
||||
if dtype is None:
|
||||
dtype = tensor.dtype
|
||||
if _USE_C_API:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
c_api.UpdateEdge(
|
||||
self._graph._c_graph, # pylint: disable=protected-access
|
||||
tensor._as_tf_output(), # pylint: disable=protected-access
|
||||
self._tf_input(index),
|
||||
status)
|
||||
else:
|
||||
dtype = dtypes.as_dtype(dtype)
|
||||
if not dtype.is_compatible_with(tensor.dtype):
|
||||
raise TypeError(
|
||||
"Cannot convert a tensor of type %s to an input of type %s" %
|
||||
(tensor.dtype.name, dtype.name))
|
||||
|
||||
self._inputs[index].consumers().remove(self)
|
||||
self._inputs[index] = tensor
|
||||
self._input_types_val[index] = dtype
|
||||
tensor._add_consumer(self) # pylint: disable=protected-access
|
||||
self._recompute_node_def()
|
||||
if dtype is None:
|
||||
dtype = tensor.dtype
|
||||
else:
|
||||
dtype = dtypes.as_dtype(dtype)
|
||||
if not dtype.is_compatible_with(tensor.dtype):
|
||||
raise TypeError(
|
||||
"Cannot convert a tensor of type %s to an input of type %s" %
|
||||
(tensor.dtype.name, dtype.name))
|
||||
self._inputs[index].consumers().remove(self)
|
||||
self._inputs[index] = tensor
|
||||
self._input_types_val[index] = dtype
|
||||
tensor._add_consumer(self) # pylint: disable=protected-access
|
||||
self._recompute_node_def()
|
||||
|
||||
def _add_control_inputs(self, ops):
|
||||
"""Add a list of new control inputs to this operation.
|
||||
|
@ -424,6 +424,70 @@ class OperationTest(test_util.TensorFlowTestCase):
|
||||
"Graph is invalid, contains a cycle with 2 nodes"):
|
||||
sess.run(x)
|
||||
|
||||
@test_util.enable_c_api
|
||||
def testUpdateInput(self):
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
x = constant_op.constant(1)
|
||||
y = constant_op.constant(2)
|
||||
z = x + y
|
||||
z.op._update_input(0, y) # pylint: disable=protected-access
|
||||
with session.Session(graph=g) as sess:
|
||||
self.assertEquals(sess.run(z), 4)
|
||||
z.op._update_input(0, x)
|
||||
with session.Session(graph=g) as sess:
|
||||
self.assertEquals(sess.run(z), 3)
|
||||
z.op._update_input(1, y)
|
||||
with session.Session(graph=g) as sess:
|
||||
self.assertEquals(sess.run(z), 3)
|
||||
|
||||
@test_util.enable_c_api
|
||||
def testUpdateInputGraphError(self):
|
||||
g_0 = ops.Graph()
|
||||
g_1 = ops.Graph()
|
||||
with g_0.as_default():
|
||||
x = constant_op.constant(1)
|
||||
with g_1.as_default():
|
||||
y = constant_op.constant(2)
|
||||
z = y * 2
|
||||
with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
|
||||
z.op._update_input(0, x) # pylint: disable=protected-access
|
||||
|
||||
# TODO(nolivia): check the shape/type in _update_input() instead of depending
|
||||
# on run to do that.
|
||||
@test_util.enable_c_api
|
||||
def testUpdateInputTypeError(self):
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
w = constant_op.constant(0)
|
||||
x = constant_op.constant("")
|
||||
y = constant_op.constant(1)
|
||||
z = y + w
|
||||
z.op._update_input(0, x) # pylint: disable=protected-access
|
||||
with session.Session(graph=g) as sess:
|
||||
with self.assertRaisesRegexp(
|
||||
errors.InvalidArgumentError,
|
||||
"Input 0 of node add was passed string from Const_1:0 incompatible "
|
||||
"with expected int32"):
|
||||
sess.run(z)
|
||||
|
||||
# C-API throws the error differently.
|
||||
def testUpdateInputOutOfRange(self):
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
x = constant_op.constant(1)
|
||||
with self.assertRaises(IndexError):
|
||||
x.op._update_input(1, x) # pylint: disable=protected-access
|
||||
|
||||
@test_util.enable_c_api
|
||||
def testUpdateInputOutOfRangeC(self):
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
x = constant_op.constant(1)
|
||||
with self.assertRaisesRegexp(errors.OutOfRangeError,
|
||||
"does not have input 1"):
|
||||
x.op._update_input(1, x) # pylint: disable=protected-access
|
||||
|
||||
|
||||
class CreateOpTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user