Avoid doing an equality check on Tensor dimensions
PiperOrigin-RevId: 261015435
This commit is contained in:
parent
1c46da48dc
commit
3351279836
@ -533,7 +533,8 @@ def _reshape_if_necessary(tensor, new_shape):
|
|||||||
new_shape = tuple(-1 if x is None else x for x in new_shape)
|
new_shape = tuple(-1 if x is None else x for x in new_shape)
|
||||||
cur_shape = tuple(x.value for x in tensor.get_shape().dims)
|
cur_shape = tuple(x.value for x in tensor.get_shape().dims)
|
||||||
if (len(new_shape) == len(cur_shape) and
|
if (len(new_shape) == len(cur_shape) and
|
||||||
all(d0 == d1 or d1 == -1 for d0, d1 in zip(cur_shape, new_shape))):
|
all(not isinstance(d1, ops.Tensor) and (d0 == d1 or d1 == -1)
|
||||||
|
for d0, d1 in zip(cur_shape, new_shape))):
|
||||||
return tensor
|
return tensor
|
||||||
else:
|
else:
|
||||||
return array_ops.reshape(tensor, new_shape)
|
return array_ops.reshape(tensor, new_shape)
|
||||||
|
Loading…
Reference in New Issue
Block a user