Validate shapes when updating edges from Python.

Uses MergeInput from shape_inference to check if the new input is compatible with the preexisting shape.  Also this changes the MergeInput method.  Previously, MergeInput would only return true if the shapes differed *and* the merge was successful.  Now, MergeInput returns true only if the merge is successful.

PiperOrigin-RevId: 175576173
This commit is contained in:
Olivia Nordquist 2017-11-13 13:07:45 -08:00 committed by TensorFlower Gardener
parent 90222dd7b2
commit bac56b37be
8 changed files with 93 additions and 31 deletions

View File

@ -46,6 +46,33 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
TF_Status* status) {
mutex_lock l(graph->mu);
tensorflow::shape_inference::InferenceContext* ic =
graph->refiner.GetContext(&new_src.oper->node);
if (ic->num_outputs() <= new_src.index) {
status->status = tensorflow::errors::OutOfRange(
"Cannot update edge. Output index [", new_src.index,
"] is greater than the number of total outputs [", ic->num_outputs(),
"].");
return;
}
tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
tensorflow::shape_inference::InferenceContext* ic_dst =
graph->refiner.GetContext(&dst.oper->node);
if (ic_dst->num_inputs() <= dst.index) {
status->status = tensorflow::errors::OutOfRange(
"Cannot update edge. Input index [", dst.index,
"] is greater than the number of total inputs [", ic_dst->num_inputs(),
"].");
return;
}
if (!ic_dst->MergeInput(dst.index, shape)) {
status->status = tensorflow::errors::InvalidArgument(
"Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
" and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
return;
}
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
&dst.oper->node, dst.index);
}

View File

@ -333,7 +333,8 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
InferenceContext* c = iter->second->get_context();
DCHECK_GE(dst_input, 0);
ShapeHandle existing_input = node_context->input(dst_input);
if (!relax && node_context->MergeInput(dst_input, c->output(src_output))) {
if (!relax && node_context->MergeInput(dst_input, c->output(src_output)) &&
!existing_input.SameHandle(node_context->input(dst_input))) {
*refined = true;
} else if (relax) {
if (node_context->RelaxInput(dst_input, c->output(src_output))) {

View File

@ -1259,7 +1259,17 @@ TEST_F(ShapeRefinerTest, IncrementalUpdates) {
EXPECT_FALSE(refined);
ctx = m.GetContext(dequeue);
EXPECT_EQ("[?,7]", ctx->DebugString(ctx->output(0)));
ASSERT_FALSE(SameHandle(ctx->Dim(ctx->output(0), 0), ctx->Dim(shp, 0)));
EXPECT_FALSE(SameHandle(ctx->Dim(ctx->output(0), 0), ctx->Dim(shp, 0)));
// Inject a shape of the same handle and expect refined to not change.
ctx = m.GetContext(queue);
shape_inference::ShapeHandle shp2 = shp;
ctx->set_output_handle_shapes_and_types(
0, std::vector<shape_inference::ShapeAndType>{{shp2, DT_FLOAT}});
refined = false;
TF_ASSERT_OK(m.UpdateNode(dequeue, /*relax=*/false, &refined));
EXPECT_FALSE(refined);
EXPECT_TRUE(SameHandle(ctx->Dim(shp, 0), ctx->Dim(shp2, 0)));
}
void TestSimpleFunctionInference(bool enable_function_inference,

View File

@ -544,9 +544,10 @@ Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
return_s1 = false;
} else if (v0 != v1) {
*out = nullptr;
return errors::InvalidArgument("Dimension ", i,
" in both shapes must be equal, but are ",
Value(d0), " and ", Value(d1));
return errors::InvalidArgument(
"Dimension ", i, " in both shapes must be equal, but are ", Value(d0),
" and ", Value(d1), ". Shapes are ", DebugString(s0), " and ",
DebugString(s1), ".");
}
}

View File

@ -237,24 +237,19 @@ class InferenceContext {
// - For any one dimension, if the values for that dimension in both shapes
// are known, then the values must match.
// - If one shape has equal or more information than the other shape in every
// dimension, the shape with more information will be returned. Otherwise a
// new shape holding the combined information of the input shapes will be
// returned.
// dimension, the new shape will become the shape with more information.
// - Example: merging [2,?] and [?,2] results in [2,2]
// - Example: [2,2] cannot be merged with [1,2]
//
// This requires idx to be in the [0, num_inputs) range. If the merge is
// successful and the new shape differs from the old one, store the new shape
// and return true. Return false otherwise.
// successful, return true. Return false otherwise.
bool MergeInput(int idx, ShapeHandle shape) {
ShapeHandle new_shape;
if (!Merge(inputs_[idx], shape, &new_shape).ok() ||
inputs_[idx].SameHandle(new_shape)) {
return false;
}
if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false;
inputs_[idx] = new_shape;
return true;
}
// Relax the stored shape of the input in position idx with <shape> according
// to the following rules:
//

View File

@ -511,6 +511,13 @@ TEST_F(GraphTest, UpdateEdge) {
EXPECT_EQ(
s.error_message(),
"Node 'A' (type: 'OneOutput', num of outputs: 1) does not have output 1");
// Update a's 1st input which is out of range.
s = graph_.UpdateEdge(c, 0, a, 0);
EXPECT_FALSE(s.ok());
EXPECT_EQ(
s.error_message(),
"Node 'A' (type: 'OneOutput', num of inputs: 0) does not have input 0");
}
TEST_F(GraphTest, InputEdges) {

View File

@ -1805,7 +1805,7 @@ class Operation(object):
tensor._add_consumer(self) # pylint: disable=protected-access
self._recompute_node_def()
def _update_input(self, index, tensor, dtype=None):
def _update_input(self, index, tensor):
"""Update the input to this operation at the given index.
NOTE: This is for TF internal use only. Please don't use it.
@ -1813,8 +1813,6 @@ class Operation(object):
Args:
index: the index of the input to update.
tensor: the Tensor to be used as the input at the given index.
dtype: tf.DType: type of the input; defaults to
the tensor's dtype.
Raises:
TypeError: if tensor is not a Tensor,
@ -1832,17 +1830,9 @@ class Operation(object):
self._tf_input(index),
status)
else:
if dtype is None:
dtype = tensor.dtype
else:
dtype = dtypes.as_dtype(dtype)
if not dtype.is_compatible_with(tensor.dtype):
raise TypeError(
"Cannot convert a tensor of type %s to an input of type %s" %
(tensor.dtype.name, dtype.name))
self._inputs[index].consumers().remove(self)
self._inputs[index] = tensor
self._input_types_val[index] = dtype
self._input_types_val[index] = tensor.dtype
tensor._add_consumer(self) # pylint: disable=protected-access
self._recompute_node_def()

View File

@ -492,8 +492,6 @@ class OperationTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
z.op._update_input(0, x) # pylint: disable=protected-access
# TODO(nolivia): check the shape/type in _update_input() instead of depending
# on run to do that.
def testUpdateInputTypeError(self):
g = ops.Graph()
with g.as_default():
@ -509,6 +507,37 @@ class OperationTest(test_util.TensorFlowTestCase):
"with expected int32"):
sess.run(z)
def testUpdateInputShapeError(self):
# C-API throws the error differently.
if ops._USE_C_API:
return
g = ops.Graph()
with g.as_default():
w = constant_op.constant(2, shape=[3, 1])
x = constant_op.constant(0, shape=[3, 1])
y = constant_op.constant(1, shape=[2, 2])
z = w + x
z.op._update_input(0, y) # pylint: disable=protected-access
with session.Session(graph=g) as sess:
with self.assertRaisesRegexp(errors.InvalidArgumentError,
r"Incompatible shapes: \[2,2\] vs. \[3,1\]"):
sess.run(z)
def testUpdateInputShapeErrorC(self):
if not ops._USE_C_API:
return
g = ops.Graph()
with g.as_default():
w = constant_op.constant(2, shape=[3, 1])
x = constant_op.constant(0, shape=[3, 1])
y = constant_op.constant(1, shape=[2, 2])
z = w + x
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"):
z.op._update_input(0, y) # pylint: disable=protected-access
def testUpdateInputOutOfRange(self):
# C-API throws the error differently.
if ops._USE_C_API: return
@ -524,9 +553,11 @@ class OperationTest(test_util.TensorFlowTestCase):
g = ops.Graph()
with g.as_default():
x = constant_op.constant(1)
with self.assertRaisesRegexp(errors.OutOfRangeError,
r"Node 'Const' \(type: 'Const', "
r"num of inputs: 0\) does not have input 1"):
with self.assertRaisesRegexp(
errors.OutOfRangeError,
r"Cannot update edge. Input index \[1\] is greater than the number of "
r"total inputs \[0\]."
):
x.op._update_input(1, x) # pylint: disable=protected-access
def testOpDef(self):