Replace set with ObjectIdentitySet to prepare for eq change in TF
PiperOrigin-RevId: 262060827
This commit is contained in:
parent
7d6b60f1ed
commit
0a25f061dd
@ -24,6 +24,7 @@ import numpy as np
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.utils.conv_utils import convert_kernel
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import object_identity
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
|
||||
@ -75,7 +76,10 @@ def count_params(weights):
|
||||
Returns:
|
||||
The total number of scalars composing the weights
|
||||
"""
|
||||
return int(sum(np.prod(p.shape.as_list()) for p in set(weights)))
|
||||
return int(
|
||||
sum(
|
||||
np.prod(p.shape.as_list())
|
||||
for p in object_identity.ObjectIdentitySet(weights)))
|
||||
|
||||
|
||||
def print_summary(model, line_length=None, positions=None, print_fn=None):
|
||||
|
Loading…
Reference in New Issue
Block a user