Replace set with ObjectIdentitySet to prepare for eq change in TF

PiperOrigin-RevId: 262060827
This commit is contained in:
Yanhua Sun 2019-08-06 21:31:22 -07:00 committed by TensorFlower Gardener
parent 7d6b60f1ed
commit 0a25f061dd

View File

@ -24,6 +24,7 @@ import numpy as np
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.conv_utils import convert_kernel
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import keras_export
@ -75,7 +76,10 @@ def count_params(weights):
Returns:
The total number of scalars composing the weights
"""
return int(sum(np.prod(p.shape.as_list()) for p in set(weights)))
return int(
sum(
np.prod(p.shape.as_list())
for p in object_identity.ObjectIdentitySet(weights)))
def print_summary(model, line_length=None, positions=None, print_fn=None):