Keep _Arg nodes in shape inference and propagate shapes in them via the "_output_shapes" attribute, which the shape function for _Arg knows how to decode.

PiperOrigin-RevId: 293399915
Change-Id: Ib77d27ecab508c99a3c8a23145b07a90cbf3841e
This commit is contained in:
A. Unique TensorFlower 2020-02-05 10:35:55 -08:00 committed by TensorFlower Gardener
parent 1a6a334ba7
commit 6d2bfcd2b9

View File

@ -664,14 +664,9 @@ class SymbolicShapeRefiner {
}
}
// Turn _Arg node into a Placeholder. _Arg node is a system op without a
// valid shape function.
*attr_output_shape.mutable_shape() = proto;
fun_node->set_op("Placeholder");
(*fun_node->mutable_attr())["dtype"] = (*fun_node->mutable_attr())["T"];
(*fun_node->mutable_attr()).erase("index");
(*fun_node->mutable_attr()).erase("T");
(*fun_node->mutable_attr())["shape"] = attr_output_shape;
AttrValue output_attr;
output_attr.mutable_list()->add_shape()->Swap(&proto);
(*fun_node->mutable_attr())["_output_shapes"] = output_attr;
}
// Replace input nodes with Consts, if values are known. Note that