Fix bug for node with attr output_shape empty.

PiperOrigin-RevId: 322825760
Change-Id: I280fab9cc2670895df4325404734afc9b3f7dafb
This commit is contained in:
A. Unique TensorFlower 2020-07-23 11:17:14 -07:00 committed by TensorFlower Gardener
parent 884bc1c3e8
commit b1c32123aa

View File

@ -612,12 +612,14 @@ class _While(_FunctionCaller):
def convert_variable_to_constant(self, incoming_edge, tensor_data):
super(_While, self).convert_variable_to_constant(incoming_edge, tensor_data)
node = self.converted_self()
node.node.attr["output_shapes"].list.shape[
incoming_edge.destination.index].CopyFrom(
tensor_shape_pb2.TensorShapeProto(dim=[
tensor_shape_pb2.TensorShapeProto.Dim(size=dim)
for dim in tensor_data.numpy.shape
]))
if node.node.attr["output_shapes"].list.shape:
node.node.attr["output_shapes"].list.shape[
incoming_edge.destination.index].CopyFrom(
tensor_shape_pb2.TensorShapeProto(dim=[
tensor_shape_pb2.TensorShapeProto.Dim(size=dim)
for dim in tensor_data.numpy.shape
]))
# The while's body inputs and outputs have the same type, so here we can go
# ahead and change that function's output type.
body_name = self._node.attr["body"].func.name