Don't try to refine the shapes for a node if its inference context wasn't
successfully built by the AddNode() method. Change: 154838211
This commit is contained in:
parent
883e32600e
commit
e8eafd94de
@ -88,7 +88,7 @@ Status ShapeRefiner::AddNode(const Node* node) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// This needs to be filled in with real data in a second pass.
|
// This needs to be filled in with real data in a second pass.
|
||||||
std::vector<const Tensor*> input_tensors(node->num_inputs());
|
std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
|
||||||
std::vector<ShapeHandle> input_tensors_as_shapes;
|
std::vector<ShapeHandle> input_tensors_as_shapes;
|
||||||
|
|
||||||
// Create the inference context for this node with the existing input shapes.
|
// Create the inference context for this node with the existing input shapes.
|
||||||
@ -145,6 +145,9 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool* refined) {
|
|||||||
}
|
}
|
||||||
InferenceContext* node_context = it->second.get();
|
InferenceContext* node_context = it->second.get();
|
||||||
|
|
||||||
|
// Give up if the context wasn't successfully built by the AddNode() method.
|
||||||
|
TF_RETURN_IF_ERROR(node_context->construction_status());
|
||||||
|
|
||||||
// Check if the shapes of the nodes in the fan-in of this node have changed,
|
// Check if the shapes of the nodes in the fan-in of this node have changed,
|
||||||
// and if they have update the node input shapes.
|
// and if they have update the node input shapes.
|
||||||
for (const Edge* e : node->in_edges()) {
|
for (const Edge* e : node->in_edges()) {
|
||||||
@ -458,7 +461,7 @@ Status ShapeRefiner::RunShapeFn(const Node* node,
|
|||||||
const OpRegistrationData* op_reg_data,
|
const OpRegistrationData* op_reg_data,
|
||||||
shape_inference::InferenceContext* c) {
|
shape_inference::InferenceContext* c) {
|
||||||
// This will be filled in with real data in a second pass.
|
// This will be filled in with real data in a second pass.
|
||||||
std::vector<const Tensor*> input_tensors(node->num_inputs());
|
std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
|
||||||
std::vector<Tensor> real_tensors(node->num_inputs());
|
std::vector<Tensor> real_tensors(node->num_inputs());
|
||||||
std::vector<bool> attempted_materialization(node->num_inputs());
|
std::vector<bool> attempted_materialization(node->num_inputs());
|
||||||
std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
|
std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
|
||||||
|
Loading…
Reference in New Issue
Block a user