From bac56b37be7736c9da9a3257696a9c1241327d60 Mon Sep 17 00:00:00 2001 From: Olivia Nordquist Date: Mon, 13 Nov 2017 13:07:45 -0800 Subject: [PATCH] Validate shapes when updating edges from Python. Uses MergeInput from shape_inference to check if the new input is compatible with the preexisting shape. Also this changes the MergeInput method. Previously, MergeInput would only return true if the shapes differed *and* the merge was successful. Now, MergeInput returns true only if the merge is successful. PiperOrigin-RevId: 175576173 --- tensorflow/c/python_api.cc | 27 ++++++++++++ .../core/common_runtime/shape_refiner.cc | 3 +- .../core/common_runtime/shape_refiner_test.cc | 12 +++++- tensorflow/core/framework/shape_inference.cc | 7 ++-- tensorflow/core/framework/shape_inference.h | 13 ++---- tensorflow/core/graph/graph_test.cc | 7 ++++ tensorflow/python/framework/ops.py | 14 +------ tensorflow/python/framework/ops_test.py | 41 ++++++++++++++++--- 8 files changed, 93 insertions(+), 31 deletions(-) diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index c67007dca0a..ba5a9268b4f 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -46,6 +46,33 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status) { mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(&new_src.oper->node); + + if (ic->num_outputs() <= new_src.index) { + status->status = tensorflow::errors::OutOfRange( + "Cannot update edge. Output index [", new_src.index, + "] is greater than the number of total outputs [", ic->num_outputs(), + "]."); + return; + } + tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index); + + tensorflow::shape_inference::InferenceContext* ic_dst = + graph->refiner.GetContext(&dst.oper->node); + if (ic_dst->num_inputs() <= dst.index) { + status->status = tensorflow::errors::OutOfRange( + "Cannot update edge. Input index [", dst.index, + "] is greater than the number of total inputs [", ic_dst->num_inputs(), + "]."); + return; + } + if (!ic_dst->MergeInput(dst.index, shape)) { + status->status = tensorflow::errors::InvalidArgument( + "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape), + " and ", ic_dst->DebugString(ic_dst->input(dst.index)), "."); + return; + } status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index, &dst.oper->node, dst.index); } diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index 1ed5eb3f228..8e314c7ea57 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -333,7 +333,8 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) { InferenceContext* c = iter->second->get_context(); DCHECK_GE(dst_input, 0); ShapeHandle existing_input = node_context->input(dst_input); - if (!relax && node_context->MergeInput(dst_input, c->output(src_output))) { + if (!relax && node_context->MergeInput(dst_input, c->output(src_output)) && + !existing_input.SameHandle(node_context->input(dst_input))) { *refined = true; } else if (relax) { if (node_context->RelaxInput(dst_input, c->output(src_output))) { diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc index 676fc7ccedf..ff32e855d59 100644 --- a/tensorflow/core/common_runtime/shape_refiner_test.cc +++ b/tensorflow/core/common_runtime/shape_refiner_test.cc @@ -1259,7 +1259,17 @@ TEST_F(ShapeRefinerTest, IncrementalUpdates) { EXPECT_FALSE(refined); ctx = m.GetContext(dequeue); EXPECT_EQ("[?,7]", ctx->DebugString(ctx->output(0))); - ASSERT_FALSE(SameHandle(ctx->Dim(ctx->output(0), 0), ctx->Dim(shp, 0))); + EXPECT_FALSE(SameHandle(ctx->Dim(ctx->output(0), 0), ctx->Dim(shp, 0))); + + // Inject a shape of the same handle and expect refined to not change. + ctx = m.GetContext(queue); + shape_inference::ShapeHandle shp2 = shp; + ctx->set_output_handle_shapes_and_types( + 0, std::vector{{shp2, DT_FLOAT}}); + refined = false; + TF_ASSERT_OK(m.UpdateNode(dequeue, /*relax=*/false, &refined)); + EXPECT_FALSE(refined); + EXPECT_TRUE(SameHandle(ctx->Dim(shp, 0), ctx->Dim(shp2, 0))); } void TestSimpleFunctionInference(bool enable_function_inference, diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 5d6bf559bb3..fe0742e1db5 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -544,9 +544,10 @@ Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1, return_s1 = false; } else if (v0 != v1) { *out = nullptr; - return errors::InvalidArgument("Dimension ", i, - " in both shapes must be equal, but are ", - Value(d0), " and ", Value(d1)); + return errors::InvalidArgument( + "Dimension ", i, " in both shapes must be equal, but are ", Value(d0), + " and ", Value(d1), ". Shapes are ", DebugString(s0), " and ", + DebugString(s1), "."); } } diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index 485980e42ee..b12d37b4c03 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -237,24 +237,19 @@ class InferenceContext { // - For any one dimension, if the values for that dimension in both shapes // are known, then the values must match. // - If one shape has equal or more information than the other shape in every - // dimension, the shape with more information will be returned. Otherwise a - // new shape holding the combined information of the input shapes will be - // returned. + // dimension, the new shape will become the shape with more information. // - Example: merging [2,?] and [?,2] results in [2,2] // - Example: [2,2] cannot be merged with [1,2] // // This requires idx to be in the [0, num_inputs) range. If the merge is - // successful and the new shape differs from the old one, store the new shape - // and return true. Return false otherwise. + // successful, return true. Return false otherwise. bool MergeInput(int idx, ShapeHandle shape) { ShapeHandle new_shape; - if (!Merge(inputs_[idx], shape, &new_shape).ok() || - inputs_[idx].SameHandle(new_shape)) { - return false; - } + if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false; inputs_[idx] = new_shape; return true; } + // Relax the stored shape of the input in position idx with according // to the following rules: // diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc index e5d57facaa7..7686cef2195 100644 --- a/tensorflow/core/graph/graph_test.cc +++ b/tensorflow/core/graph/graph_test.cc @@ -511,6 +511,13 @@ TEST_F(GraphTest, UpdateEdge) { EXPECT_EQ( s.error_message(), "Node 'A' (type: 'OneOutput', num of outputs: 1) does not have output 1"); + + // Update a's 1st input which is out of range. + s = graph_.UpdateEdge(c, 0, a, 0); + EXPECT_FALSE(s.ok()); + EXPECT_EQ( + s.error_message(), + "Node 'A' (type: 'OneOutput', num of inputs: 0) does not have input 0"); } TEST_F(GraphTest, InputEdges) { diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index d2608845ac7..b0abbfc7dcc 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -1805,7 +1805,7 @@ class Operation(object): tensor._add_consumer(self) # pylint: disable=protected-access self._recompute_node_def() - def _update_input(self, index, tensor, dtype=None): + def _update_input(self, index, tensor): """Update the input to this operation at the given index. NOTE: This is for TF internal use only. Please don't use it. @@ -1813,8 +1813,6 @@ class Operation(object): Args: index: the index of the input to update. tensor: the Tensor to be used as the input at the given index. - dtype: tf.DType: type of the input; defaults to - the tensor's dtype. Raises: TypeError: if tensor is not a Tensor, @@ -1832,17 +1830,9 @@ class Operation(object): self._tf_input(index), status) else: - if dtype is None: - dtype = tensor.dtype - else: - dtype = dtypes.as_dtype(dtype) - if not dtype.is_compatible_with(tensor.dtype): - raise TypeError( - "Cannot convert a tensor of type %s to an input of type %s" % - (tensor.dtype.name, dtype.name)) self._inputs[index].consumers().remove(self) self._inputs[index] = tensor - self._input_types_val[index] = dtype + self._input_types_val[index] = tensor.dtype tensor._add_consumer(self) # pylint: disable=protected-access self._recompute_node_def() diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 4e931e00c59..1be306ddc59 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -492,8 +492,6 @@ class OperationTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp(ValueError, "must be from the same graph"): z.op._update_input(0, x) # pylint: disable=protected-access - # TODO(nolivia): check the shape/type in _update_input() instead of depending - # on run to do that. def testUpdateInputTypeError(self): g = ops.Graph() with g.as_default(): @@ -509,6 +507,37 @@ class OperationTest(test_util.TensorFlowTestCase): "with expected int32"): sess.run(z) + def testUpdateInputShapeError(self): + # C-API throws the error differently. + if ops._USE_C_API: + return + g = ops.Graph() + with g.as_default(): + w = constant_op.constant(2, shape=[3, 1]) + x = constant_op.constant(0, shape=[3, 1]) + y = constant_op.constant(1, shape=[2, 2]) + z = w + x + z.op._update_input(0, y) # pylint: disable=protected-access + + with session.Session(graph=g) as sess: + with self.assertRaisesRegexp(errors.InvalidArgumentError, + r"Incompatible shapes: \[2,2\] vs. \[3,1\]"): + sess.run(z) + + def testUpdateInputShapeErrorC(self): + if not ops._USE_C_API: + return + g = ops.Graph() + with g.as_default(): + w = constant_op.constant(2, shape=[3, 1]) + x = constant_op.constant(0, shape=[3, 1]) + y = constant_op.constant(1, shape=[2, 2]) + z = w + x + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"): + z.op._update_input(0, y) # pylint: disable=protected-access + def testUpdateInputOutOfRange(self): # C-API throws the error differently. if ops._USE_C_API: return @@ -524,9 +553,11 @@ class OperationTest(test_util.TensorFlowTestCase): g = ops.Graph() with g.as_default(): x = constant_op.constant(1) - with self.assertRaisesRegexp(errors.OutOfRangeError, - r"Node 'Const' \(type: 'Const', " - r"num of inputs: 0\) does not have input 1"): + with self.assertRaisesRegexp( + errors.OutOfRangeError, + r"Cannot update edge. Input index \[1\] is greater than the number of " + r"total inputs \[0\]." + ): x.op._update_input(1, x) # pylint: disable=protected-access def testOpDef(self):