Validate shapes when updating edges from Python.
Uses MergeInput from shape_inference to check whether the new input shape is compatible with the preexisting shape. This also changes the MergeInput method: previously, MergeInput returned true only if the shapes differed *and* the merge was successful; now it returns true whenever the merge is successful.

PiperOrigin-RevId: 175576173
parent 90222dd7b2
commit bac56b37be
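For a sense of the user-visible effect, here is a minimal sketch that mirrors the new testUpdateInputShapeErrorC test below; it assumes a TensorFlow build at this revision with the C API path enabled (ops._USE_C_API) and uses the internal Operation._update_input helper purely for illustration.

# Illustrative sketch only; mirrors testUpdateInputShapeErrorC from this change.
# Assumes a build at this revision with ops._USE_C_API enabled.
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops

g = ops.Graph()
with g.as_default():
  w = constant_op.constant(2, shape=[3, 1])
  x = constant_op.constant(0, shape=[3, 1])
  y = constant_op.constant(1, shape=[2, 2])
  z = w + x
  try:
    # Rewiring z's first input to an incompatibly shaped tensor now fails at
    # graph-construction time with "Cannot update edge, incompatible shapes:
    # [2,2] and [3,1]" instead of only failing later when the graph is run.
    z.op._update_input(0, y)  # pylint: disable=protected-access
  except errors.InvalidArgumentError as e:
    print(e.message)

Without the C API path, the same mismatch still only surfaces as an "Incompatible shapes" error when the graph is run, as testUpdateInputShapeError below shows.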
@@ -46,6 +46,33 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
 void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
                 TF_Status* status) {
   mutex_lock l(graph->mu);
+  tensorflow::shape_inference::InferenceContext* ic =
+      graph->refiner.GetContext(&new_src.oper->node);
+
+  if (ic->num_outputs() <= new_src.index) {
+    status->status = tensorflow::errors::OutOfRange(
+        "Cannot update edge. Output index [", new_src.index,
+        "] is greater than the number of total outputs [", ic->num_outputs(),
+        "].");
+    return;
+  }
+  tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
+
+  tensorflow::shape_inference::InferenceContext* ic_dst =
+      graph->refiner.GetContext(&dst.oper->node);
+  if (ic_dst->num_inputs() <= dst.index) {
+    status->status = tensorflow::errors::OutOfRange(
+        "Cannot update edge. Input index [", dst.index,
+        "] is greater than the number of total inputs [", ic_dst->num_inputs(),
+        "].");
+    return;
+  }
+  if (!ic_dst->MergeInput(dst.index, shape)) {
+    status->status = tensorflow::errors::InvalidArgument(
+        "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
+        " and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
+    return;
+  }
   status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
                                            &dst.oper->node, dst.index);
 }
@@ -333,7 +333,8 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
     InferenceContext* c = iter->second->get_context();
     DCHECK_GE(dst_input, 0);
     ShapeHandle existing_input = node_context->input(dst_input);
-    if (!relax && node_context->MergeInput(dst_input, c->output(src_output))) {
+    if (!relax && node_context->MergeInput(dst_input, c->output(src_output)) &&
+        !existing_input.SameHandle(node_context->input(dst_input))) {
       *refined = true;
     } else if (relax) {
       if (node_context->RelaxInput(dst_input, c->output(src_output))) {
@@ -1259,7 +1259,17 @@ TEST_F(ShapeRefinerTest, IncrementalUpdates) {
   EXPECT_FALSE(refined);
   ctx = m.GetContext(dequeue);
   EXPECT_EQ("[?,7]", ctx->DebugString(ctx->output(0)));
-  ASSERT_FALSE(SameHandle(ctx->Dim(ctx->output(0), 0), ctx->Dim(shp, 0)));
+  EXPECT_FALSE(SameHandle(ctx->Dim(ctx->output(0), 0), ctx->Dim(shp, 0)));
+
+  // Inject a shape of the same handle and expect refined to not change.
+  ctx = m.GetContext(queue);
+  shape_inference::ShapeHandle shp2 = shp;
+  ctx->set_output_handle_shapes_and_types(
+      0, std::vector<shape_inference::ShapeAndType>{{shp2, DT_FLOAT}});
+  refined = false;
+  TF_ASSERT_OK(m.UpdateNode(dequeue, /*relax=*/false, &refined));
+  EXPECT_FALSE(refined);
+  EXPECT_TRUE(SameHandle(ctx->Dim(shp, 0), ctx->Dim(shp2, 0)));
 }

 void TestSimpleFunctionInference(bool enable_function_inference,
@@ -544,9 +544,10 @@ Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
       return_s1 = false;
     } else if (v0 != v1) {
       *out = nullptr;
-      return errors::InvalidArgument("Dimension ", i,
-                                     " in both shapes must be equal, but are ",
-                                     Value(d0), " and ", Value(d1));
+      return errors::InvalidArgument(
+          "Dimension ", i, " in both shapes must be equal, but are ", Value(d0),
+          " and ", Value(d1), ". Shapes are ", DebugString(s0), " and ",
+          DebugString(s1), ".");
     }
   }
@@ -237,24 +237,19 @@ class InferenceContext {
   // - For any one dimension, if the values for that dimension in both shapes
   //   are known, then the values must match.
   // - If one shape has equal or more information than the other shape in every
-  //   dimension, the shape with more information will be returned. Otherwise a
-  //   new shape holding the combined information of the input shapes will be
-  //   returned.
+  //   dimension, the new shape will become the shape with more information.
   // - Example: merging [2,?] and [?,2] results in [2,2]
   // - Example: [2,2] cannot be merged with [1,2]
   //
   // This requires idx to be in the [0, num_inputs) range. If the merge is
-  // successful and the new shape differs from the old one, store the new shape
-  // and return true. Return false otherwise.
+  // successful, return true. Return false otherwise.
   bool MergeInput(int idx, ShapeHandle shape) {
     ShapeHandle new_shape;
-    if (!Merge(inputs_[idx], shape, &new_shape).ok() ||
-        inputs_[idx].SameHandle(new_shape)) {
-      return false;
-    }
+    if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false;
     inputs_[idx] = new_shape;
     return true;
   }

   // Relax the stored shape of the input in position idx with <shape> according
   // to the following rules:
   //
@@ -511,6 +511,13 @@ TEST_F(GraphTest, UpdateEdge) {
   EXPECT_EQ(
       s.error_message(),
      "Node 'A' (type: 'OneOutput', num of outputs: 1) does not have output 1");
+
+  // Update a's 1st input which is out of range.
+  s = graph_.UpdateEdge(c, 0, a, 0);
+  EXPECT_FALSE(s.ok());
+  EXPECT_EQ(
+      s.error_message(),
+      "Node 'A' (type: 'OneOutput', num of inputs: 0) does not have input 0");
 }

 TEST_F(GraphTest, InputEdges) {
@@ -1805,7 +1805,7 @@ class Operation(object):
     tensor._add_consumer(self)  # pylint: disable=protected-access
     self._recompute_node_def()

-  def _update_input(self, index, tensor, dtype=None):
+  def _update_input(self, index, tensor):
     """Update the input to this operation at the given index.

     NOTE: This is for TF internal use only. Please don't use it.
@@ -1813,8 +1813,6 @@ class Operation(object):
     Args:
       index: the index of the input to update.
       tensor: the Tensor to be used as the input at the given index.
-      dtype: tf.DType: type of the input; defaults to
-        the tensor's dtype.

     Raises:
       TypeError: if tensor is not a Tensor,
@@ -1832,17 +1830,9 @@ class Operation(object):
             self._tf_input(index),
             status)
     else:
-      if dtype is None:
-        dtype = tensor.dtype
-      else:
-        dtype = dtypes.as_dtype(dtype)
-        if not dtype.is_compatible_with(tensor.dtype):
-          raise TypeError(
-              "Cannot convert a tensor of type %s to an input of type %s" %
-              (tensor.dtype.name, dtype.name))
       self._inputs[index].consumers().remove(self)
       self._inputs[index] = tensor
-      self._input_types_val[index] = dtype
+      self._input_types_val[index] = tensor.dtype
       tensor._add_consumer(self)  # pylint: disable=protected-access
       self._recompute_node_def()
@@ -492,8 +492,6 @@ class OperationTest(test_util.TensorFlowTestCase):
     with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
       z.op._update_input(0, x)  # pylint: disable=protected-access

-  # TODO(nolivia): check the shape/type in _update_input() instead of depending
-  # on run to do that.
   def testUpdateInputTypeError(self):
     g = ops.Graph()
     with g.as_default():
@@ -509,6 +507,37 @@ class OperationTest(test_util.TensorFlowTestCase):
                                    "with expected int32"):
         sess.run(z)

+  def testUpdateInputShapeError(self):
+    # C-API throws the error differently.
+    if ops._USE_C_API:
+      return
+    g = ops.Graph()
+    with g.as_default():
+      w = constant_op.constant(2, shape=[3, 1])
+      x = constant_op.constant(0, shape=[3, 1])
+      y = constant_op.constant(1, shape=[2, 2])
+      z = w + x
+      z.op._update_input(0, y)  # pylint: disable=protected-access
+
+    with session.Session(graph=g) as sess:
+      with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                   r"Incompatible shapes: \[2,2\] vs. \[3,1\]"):
+        sess.run(z)
+
+  def testUpdateInputShapeErrorC(self):
+    if not ops._USE_C_API:
+      return
+    g = ops.Graph()
+    with g.as_default():
+      w = constant_op.constant(2, shape=[3, 1])
+      x = constant_op.constant(0, shape=[3, 1])
+      y = constant_op.constant(1, shape=[2, 2])
+      z = w + x
+      with self.assertRaisesRegexp(
+          errors.InvalidArgumentError,
+          r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"):
+        z.op._update_input(0, y)  # pylint: disable=protected-access
+
   def testUpdateInputOutOfRange(self):
     # C-API throws the error differently.
     if ops._USE_C_API: return
@@ -524,9 +553,11 @@ class OperationTest(test_util.TensorFlowTestCase):
     g = ops.Graph()
     with g.as_default():
       x = constant_op.constant(1)
-      with self.assertRaisesRegexp(errors.OutOfRangeError,
-                                   r"Node 'Const' \(type: 'Const', "
-                                   r"num of inputs: 0\) does not have input 1"):
+      with self.assertRaisesRegexp(
+          errors.OutOfRangeError,
+          r"Cannot update edge. Input index \[1\] is greater than the number of "
+          r"total inputs \[0\]."
+      ):
         x.op._update_input(1, x)  # pylint: disable=protected-access

   def testOpDef(self):