Validate shapes when updating edges from Python.
Uses MergeInput from shape_inference to check if the new input is compatible with the preexisting shape. Also this changes the MergeInput method. Previously, MergeInput would only return true if the shapes differed *and* the merge was successful. Now, MergeInput returns true only if the merge is successful. PiperOrigin-RevId: 175576173
This commit is contained in:
parent
90222dd7b2
commit
bac56b37be
tensorflow
c
core
common_runtime
framework
graph
python/framework
@ -46,6 +46,33 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
|
|||||||
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
|
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
mutex_lock l(graph->mu);
|
mutex_lock l(graph->mu);
|
||||||
|
tensorflow::shape_inference::InferenceContext* ic =
|
||||||
|
graph->refiner.GetContext(&new_src.oper->node);
|
||||||
|
|
||||||
|
if (ic->num_outputs() <= new_src.index) {
|
||||||
|
status->status = tensorflow::errors::OutOfRange(
|
||||||
|
"Cannot update edge. Output index [", new_src.index,
|
||||||
|
"] is greater than the number of total outputs [", ic->num_outputs(),
|
||||||
|
"].");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
|
||||||
|
|
||||||
|
tensorflow::shape_inference::InferenceContext* ic_dst =
|
||||||
|
graph->refiner.GetContext(&dst.oper->node);
|
||||||
|
if (ic_dst->num_inputs() <= dst.index) {
|
||||||
|
status->status = tensorflow::errors::OutOfRange(
|
||||||
|
"Cannot update edge. Input index [", dst.index,
|
||||||
|
"] is greater than the number of total inputs [", ic_dst->num_inputs(),
|
||||||
|
"].");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!ic_dst->MergeInput(dst.index, shape)) {
|
||||||
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
|
||||||
|
" and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
|
||||||
|
return;
|
||||||
|
}
|
||||||
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
|
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
|
||||||
&dst.oper->node, dst.index);
|
&dst.oper->node, dst.index);
|
||||||
}
|
}
|
||||||
|
@ -333,7 +333,8 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
|
|||||||
InferenceContext* c = iter->second->get_context();
|
InferenceContext* c = iter->second->get_context();
|
||||||
DCHECK_GE(dst_input, 0);
|
DCHECK_GE(dst_input, 0);
|
||||||
ShapeHandle existing_input = node_context->input(dst_input);
|
ShapeHandle existing_input = node_context->input(dst_input);
|
||||||
if (!relax && node_context->MergeInput(dst_input, c->output(src_output))) {
|
if (!relax && node_context->MergeInput(dst_input, c->output(src_output)) &&
|
||||||
|
!existing_input.SameHandle(node_context->input(dst_input))) {
|
||||||
*refined = true;
|
*refined = true;
|
||||||
} else if (relax) {
|
} else if (relax) {
|
||||||
if (node_context->RelaxInput(dst_input, c->output(src_output))) {
|
if (node_context->RelaxInput(dst_input, c->output(src_output))) {
|
||||||
|
@ -1259,7 +1259,17 @@ TEST_F(ShapeRefinerTest, IncrementalUpdates) {
|
|||||||
EXPECT_FALSE(refined);
|
EXPECT_FALSE(refined);
|
||||||
ctx = m.GetContext(dequeue);
|
ctx = m.GetContext(dequeue);
|
||||||
EXPECT_EQ("[?,7]", ctx->DebugString(ctx->output(0)));
|
EXPECT_EQ("[?,7]", ctx->DebugString(ctx->output(0)));
|
||||||
ASSERT_FALSE(SameHandle(ctx->Dim(ctx->output(0), 0), ctx->Dim(shp, 0)));
|
EXPECT_FALSE(SameHandle(ctx->Dim(ctx->output(0), 0), ctx->Dim(shp, 0)));
|
||||||
|
|
||||||
|
// Inject a shape of the same handle and expect refined to not change.
|
||||||
|
ctx = m.GetContext(queue);
|
||||||
|
shape_inference::ShapeHandle shp2 = shp;
|
||||||
|
ctx->set_output_handle_shapes_and_types(
|
||||||
|
0, std::vector<shape_inference::ShapeAndType>{{shp2, DT_FLOAT}});
|
||||||
|
refined = false;
|
||||||
|
TF_ASSERT_OK(m.UpdateNode(dequeue, /*relax=*/false, &refined));
|
||||||
|
EXPECT_FALSE(refined);
|
||||||
|
EXPECT_TRUE(SameHandle(ctx->Dim(shp, 0), ctx->Dim(shp2, 0)));
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestSimpleFunctionInference(bool enable_function_inference,
|
void TestSimpleFunctionInference(bool enable_function_inference,
|
||||||
|
@ -544,9 +544,10 @@ Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
|
|||||||
return_s1 = false;
|
return_s1 = false;
|
||||||
} else if (v0 != v1) {
|
} else if (v0 != v1) {
|
||||||
*out = nullptr;
|
*out = nullptr;
|
||||||
return errors::InvalidArgument("Dimension ", i,
|
return errors::InvalidArgument(
|
||||||
" in both shapes must be equal, but are ",
|
"Dimension ", i, " in both shapes must be equal, but are ", Value(d0),
|
||||||
Value(d0), " and ", Value(d1));
|
" and ", Value(d1), ". Shapes are ", DebugString(s0), " and ",
|
||||||
|
DebugString(s1), ".");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -237,24 +237,19 @@ class InferenceContext {
|
|||||||
// - For any one dimension, if the values for that dimension in both shapes
|
// - For any one dimension, if the values for that dimension in both shapes
|
||||||
// are known, then the values must match.
|
// are known, then the values must match.
|
||||||
// - If one shape has equal or more information than the other shape in every
|
// - If one shape has equal or more information than the other shape in every
|
||||||
// dimension, the shape with more information will be returned. Otherwise a
|
// dimension, the new shape will become the shape with more information.
|
||||||
// new shape holding the combined information of the input shapes will be
|
|
||||||
// returned.
|
|
||||||
// - Example: merging [2,?] and [?,2] results in [2,2]
|
// - Example: merging [2,?] and [?,2] results in [2,2]
|
||||||
// - Example: [2,2] cannot be merged with [1,2]
|
// - Example: [2,2] cannot be merged with [1,2]
|
||||||
//
|
//
|
||||||
// This requires idx to be in the [0, num_inputs) range. If the merge is
|
// This requires idx to be in the [0, num_inputs) range. If the merge is
|
||||||
// successful and the new shape differs from the old one, store the new shape
|
// successful, return true. Return false otherwise.
|
||||||
// and return true. Return false otherwise.
|
|
||||||
bool MergeInput(int idx, ShapeHandle shape) {
|
bool MergeInput(int idx, ShapeHandle shape) {
|
||||||
ShapeHandle new_shape;
|
ShapeHandle new_shape;
|
||||||
if (!Merge(inputs_[idx], shape, &new_shape).ok() ||
|
if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false;
|
||||||
inputs_[idx].SameHandle(new_shape)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
inputs_[idx] = new_shape;
|
inputs_[idx] = new_shape;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Relax the stored shape of the input in position idx with <shape> according
|
// Relax the stored shape of the input in position idx with <shape> according
|
||||||
// to the following rules:
|
// to the following rules:
|
||||||
//
|
//
|
||||||
|
@ -511,6 +511,13 @@ TEST_F(GraphTest, UpdateEdge) {
|
|||||||
EXPECT_EQ(
|
EXPECT_EQ(
|
||||||
s.error_message(),
|
s.error_message(),
|
||||||
"Node 'A' (type: 'OneOutput', num of outputs: 1) does not have output 1");
|
"Node 'A' (type: 'OneOutput', num of outputs: 1) does not have output 1");
|
||||||
|
|
||||||
|
// Update a's 1st input which is out of range.
|
||||||
|
s = graph_.UpdateEdge(c, 0, a, 0);
|
||||||
|
EXPECT_FALSE(s.ok());
|
||||||
|
EXPECT_EQ(
|
||||||
|
s.error_message(),
|
||||||
|
"Node 'A' (type: 'OneOutput', num of inputs: 0) does not have input 0");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GraphTest, InputEdges) {
|
TEST_F(GraphTest, InputEdges) {
|
||||||
|
@ -1805,7 +1805,7 @@ class Operation(object):
|
|||||||
tensor._add_consumer(self) # pylint: disable=protected-access
|
tensor._add_consumer(self) # pylint: disable=protected-access
|
||||||
self._recompute_node_def()
|
self._recompute_node_def()
|
||||||
|
|
||||||
def _update_input(self, index, tensor, dtype=None):
|
def _update_input(self, index, tensor):
|
||||||
"""Update the input to this operation at the given index.
|
"""Update the input to this operation at the given index.
|
||||||
|
|
||||||
NOTE: This is for TF internal use only. Please don't use it.
|
NOTE: This is for TF internal use only. Please don't use it.
|
||||||
@ -1813,8 +1813,6 @@ class Operation(object):
|
|||||||
Args:
|
Args:
|
||||||
index: the index of the input to update.
|
index: the index of the input to update.
|
||||||
tensor: the Tensor to be used as the input at the given index.
|
tensor: the Tensor to be used as the input at the given index.
|
||||||
dtype: tf.DType: type of the input; defaults to
|
|
||||||
the tensor's dtype.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: if tensor is not a Tensor,
|
TypeError: if tensor is not a Tensor,
|
||||||
@ -1832,17 +1830,9 @@ class Operation(object):
|
|||||||
self._tf_input(index),
|
self._tf_input(index),
|
||||||
status)
|
status)
|
||||||
else:
|
else:
|
||||||
if dtype is None:
|
|
||||||
dtype = tensor.dtype
|
|
||||||
else:
|
|
||||||
dtype = dtypes.as_dtype(dtype)
|
|
||||||
if not dtype.is_compatible_with(tensor.dtype):
|
|
||||||
raise TypeError(
|
|
||||||
"Cannot convert a tensor of type %s to an input of type %s" %
|
|
||||||
(tensor.dtype.name, dtype.name))
|
|
||||||
self._inputs[index].consumers().remove(self)
|
self._inputs[index].consumers().remove(self)
|
||||||
self._inputs[index] = tensor
|
self._inputs[index] = tensor
|
||||||
self._input_types_val[index] = dtype
|
self._input_types_val[index] = tensor.dtype
|
||||||
tensor._add_consumer(self) # pylint: disable=protected-access
|
tensor._add_consumer(self) # pylint: disable=protected-access
|
||||||
self._recompute_node_def()
|
self._recompute_node_def()
|
||||||
|
|
||||||
|
@ -492,8 +492,6 @@ class OperationTest(test_util.TensorFlowTestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
|
with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
|
||||||
z.op._update_input(0, x) # pylint: disable=protected-access
|
z.op._update_input(0, x) # pylint: disable=protected-access
|
||||||
|
|
||||||
# TODO(nolivia): check the shape/type in _update_input() instead of depending
|
|
||||||
# on run to do that.
|
|
||||||
def testUpdateInputTypeError(self):
|
def testUpdateInputTypeError(self):
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
with g.as_default():
|
with g.as_default():
|
||||||
@ -509,6 +507,37 @@ class OperationTest(test_util.TensorFlowTestCase):
|
|||||||
"with expected int32"):
|
"with expected int32"):
|
||||||
sess.run(z)
|
sess.run(z)
|
||||||
|
|
||||||
|
def testUpdateInputShapeError(self):
|
||||||
|
# C-API throws the error differently.
|
||||||
|
if ops._USE_C_API:
|
||||||
|
return
|
||||||
|
g = ops.Graph()
|
||||||
|
with g.as_default():
|
||||||
|
w = constant_op.constant(2, shape=[3, 1])
|
||||||
|
x = constant_op.constant(0, shape=[3, 1])
|
||||||
|
y = constant_op.constant(1, shape=[2, 2])
|
||||||
|
z = w + x
|
||||||
|
z.op._update_input(0, y) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
with session.Session(graph=g) as sess:
|
||||||
|
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||||
|
r"Incompatible shapes: \[2,2\] vs. \[3,1\]"):
|
||||||
|
sess.run(z)
|
||||||
|
|
||||||
|
def testUpdateInputShapeErrorC(self):
|
||||||
|
if not ops._USE_C_API:
|
||||||
|
return
|
||||||
|
g = ops.Graph()
|
||||||
|
with g.as_default():
|
||||||
|
w = constant_op.constant(2, shape=[3, 1])
|
||||||
|
x = constant_op.constant(0, shape=[3, 1])
|
||||||
|
y = constant_op.constant(1, shape=[2, 2])
|
||||||
|
z = w + x
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
errors.InvalidArgumentError,
|
||||||
|
r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"):
|
||||||
|
z.op._update_input(0, y) # pylint: disable=protected-access
|
||||||
|
|
||||||
def testUpdateInputOutOfRange(self):
|
def testUpdateInputOutOfRange(self):
|
||||||
# C-API throws the error differently.
|
# C-API throws the error differently.
|
||||||
if ops._USE_C_API: return
|
if ops._USE_C_API: return
|
||||||
@ -524,9 +553,11 @@ class OperationTest(test_util.TensorFlowTestCase):
|
|||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
with g.as_default():
|
with g.as_default():
|
||||||
x = constant_op.constant(1)
|
x = constant_op.constant(1)
|
||||||
with self.assertRaisesRegexp(errors.OutOfRangeError,
|
with self.assertRaisesRegexp(
|
||||||
r"Node 'Const' \(type: 'Const', "
|
errors.OutOfRangeError,
|
||||||
r"num of inputs: 0\) does not have input 1"):
|
r"Cannot update edge. Input index \[1\] is greater than the number of "
|
||||||
|
r"total inputs \[0\]."
|
||||||
|
):
|
||||||
x.op._update_input(1, x) # pylint: disable=protected-access
|
x.op._update_input(1, x) # pylint: disable=protected-access
|
||||||
|
|
||||||
def testOpDef(self):
|
def testOpDef(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user