#PRIVATE_TF_API_USAGE_CLEANUP Remove the usage of gather_non_trainable_weights. There is no reference to this method, so we just delete the method.
PiperOrigin-RevId: 353274439 Change-Id: I876eac533fd6a68912900118b4159651741e7978
This commit is contained in:
parent
70e20fd2a0
commit
101b68041c
@ -299,38 +299,6 @@ def gather_trainable_weights(trainable, sub_layers, extra_variables):
|
|||||||
return weights + trainable_extra_variables
|
return weights + trainable_extra_variables
|
||||||
|
|
||||||
|
|
||||||
def gather_non_trainable_weights(trainable, sub_layers, extra_variables):
|
|
||||||
"""Lists the non-trainable weights for an object with sub-layers.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
trainable: Whether the object collecting the variables is trainable.
|
|
||||||
sub_layers: A flat list of Layer objects owned by this object, to collect
|
|
||||||
variables from.
|
|
||||||
extra_variables: Any extra variables to include. Their `.trainable` property
|
|
||||||
is used to categorize them.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of collected non-trainable weights/variables.
|
|
||||||
"""
|
|
||||||
trainable_extra_variables = []
|
|
||||||
non_trainable_extra_variables = []
|
|
||||||
for v in extra_variables:
|
|
||||||
if v.trainable:
|
|
||||||
trainable_extra_variables.append(v)
|
|
||||||
else:
|
|
||||||
non_trainable_extra_variables.append(v)
|
|
||||||
weights = []
|
|
||||||
for layer in sub_layers:
|
|
||||||
weights += layer.non_trainable_weights
|
|
||||||
if not trainable:
|
|
||||||
trainable_weights = []
|
|
||||||
for layer in sub_layers:
|
|
||||||
trainable_weights += layer.trainable_weights
|
|
||||||
return (trainable_weights + trainable_extra_variables
|
|
||||||
+ weights + non_trainable_extra_variables)
|
|
||||||
return weights + non_trainable_extra_variables
|
|
||||||
|
|
||||||
|
|
||||||
def convert_dense_weights_data_format(dense,
|
def convert_dense_weights_data_format(dense,
|
||||||
previous_feature_map_shape,
|
previous_feature_map_shape,
|
||||||
target_data_format='channels_first'):
|
target_data_format='channels_first'):
|
||||||
|
@ -179,35 +179,3 @@ def gather_trainable_weights(trainable, sub_layers, extra_variables):
|
|||||||
trainable_extra_variables = [
|
trainable_extra_variables = [
|
||||||
v for v in extra_variables if v.trainable]
|
v for v in extra_variables if v.trainable]
|
||||||
return weights + trainable_extra_variables
|
return weights + trainable_extra_variables
|
||||||
|
|
||||||
|
|
||||||
def gather_non_trainable_weights(trainable, sub_layers, extra_variables):
|
|
||||||
"""Lists the non-trainable weights for an object with sub-layers.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
trainable: Whether the object collecting the variables is trainable.
|
|
||||||
sub_layers: A flat list of Layer objects owned by this object, to collect
|
|
||||||
variables from.
|
|
||||||
extra_variables: Any extra variables to include. Their `.trainable` property
|
|
||||||
is used to categorize them.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of collected non-trainable weights/variables.
|
|
||||||
"""
|
|
||||||
trainable_extra_variables = []
|
|
||||||
non_trainable_extra_variables = []
|
|
||||||
for v in extra_variables:
|
|
||||||
if v.trainable:
|
|
||||||
trainable_extra_variables.append(v)
|
|
||||||
else:
|
|
||||||
non_trainable_extra_variables.append(v)
|
|
||||||
weights = []
|
|
||||||
for layer in sub_layers:
|
|
||||||
weights += layer.non_trainable_weights
|
|
||||||
if not trainable:
|
|
||||||
trainable_weights = []
|
|
||||||
for layer in sub_layers:
|
|
||||||
trainable_weights += layer.trainable_weights
|
|
||||||
return (trainable_weights + trainable_extra_variables
|
|
||||||
+ weights + non_trainable_extra_variables)
|
|
||||||
return weights + non_trainable_extra_variables
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user