pylint fix
This commit is contained in:
parent
0de0ac1b11
commit
10957aaf35
@ -74,7 +74,7 @@ def write_ckpt_to_h5(path_h5, path_ckpt, keras_model, use_ema=True):
|
||||
keras_weight_names,
|
||||
tf_weight_names,
|
||||
model_name_tf)
|
||||
print(f'{keras_block} and {tf_block} match.')
|
||||
print('{} and {} match.'.format(tf_block, keras_block))
|
||||
|
||||
block_mapping = {x[0]: x[1] for x in zip(keras_blocks, tf_blocks)}
|
||||
|
||||
@ -94,11 +94,12 @@ def write_ckpt_to_h5(path_h5, path_ckpt, keras_model, use_ema=True):
|
||||
use_ema=use_ema,
|
||||
model_name_tf=model_name_tf)
|
||||
elif 'normalization' in w.name:
|
||||
print(f'skipping variable {w.name}: normalization is a layer'
|
||||
'in keras implementation, but preprocessing in TF implementation.')
|
||||
print('skipping variable {}: normalization is a layer'
|
||||
'in keras implementation, but preprocessing in '
|
||||
'TF implementation.'.format(w.name))
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f'{w.name} failed to parse.')
|
||||
raise ValueError('{} failed to parse.'.format(w.name))
|
||||
|
||||
try:
|
||||
w_tf = tf.train.load_variable(path_ckpt, tf_name)
|
||||
@ -107,11 +108,13 @@ def write_ckpt_to_h5(path_h5, path_ckpt, keras_model, use_ema=True):
|
||||
changed_weights += 1
|
||||
except ValueError as e:
|
||||
if any([x in w.name for x in ['top', 'predictions', 'probs']]):
|
||||
warnings.warn(f'Fail to load top layer variable {w.name}'
|
||||
f'from {tf_name} because of {e}.')
|
||||
warnings.warn('Fail to load top layer variable {}'
|
||||
'from {} because of {}.'.format(w.name, tf_name, e))
|
||||
else:
|
||||
raise ValueError(f'Fail to load {w.name} from {tf_name}')
|
||||
print(f'{changed_weights}/{len(keras_model.weights)} weights updated')
|
||||
raise ValueError('Fail to load {} from {}'.format(w.name, tf_name))
|
||||
|
||||
total_weights = len(keras_model.weights)
|
||||
print('{}/{} weights updated'.format(changed_weights, total_weights))
|
||||
keras_model.save_weights(path_h5)
|
||||
|
||||
|
||||
@ -178,30 +181,34 @@ def keras_name_to_tf_name_stem_top(keras_name,
|
||||
ema = ''
|
||||
|
||||
stem_top_dict = {
|
||||
'probs/bias:0': f'{model_name_tf}/head/dense/bias{ema}',
|
||||
'probs/kernel:0': f'{model_name_tf}/head/dense/kernel{ema}',
|
||||
'predictions/bias:0': f'{model_name_tf}/head/dense/bias{ema}',
|
||||
'predictions/kernel:0': f'{model_name_tf}/head/dense/kernel{ema}',
|
||||
'stem_conv/kernel:0': f'{model_name_tf}/stem/conv2d/kernel{ema}',
|
||||
'top_conv/kernel:0': f'{model_name_tf}/head/conv2d/kernel{ema}',
|
||||
'probs/bias:0': '{}/head/dense/bias{}',
|
||||
'probs/kernel:0': '{}/head/dense/kernel{}',
|
||||
'predictions/bias:0': '{}/head/dense/bias{}',
|
||||
'predictions/kernel:0': '{}/head/dense/kernel{}',
|
||||
'stem_conv/kernel:0': '{}/stem/conv2d/kernel{}',
|
||||
'top_conv/kernel:0': '{}/head/conv2d/kernel{}',
|
||||
}
|
||||
for x in stem_top_dict:
|
||||
stem_top_dict[x] = stem_top_dict[x].format(model_name_tf, ema)
|
||||
|
||||
|
||||
# stem batch normalization
|
||||
for bn_weights in ['beta', 'gamma', 'moving_mean', 'moving_variance']:
|
||||
f_string = '{}/stem/tpu_batch_normalization/{}{}'
|
||||
stem_top_dict[f'stem_bn/{bn_weights}:0'] = f_string.format(model_name_tf,
|
||||
bn_weights,
|
||||
ema)
|
||||
tf_name = '{}/stem/tpu_batch_normalization/{}{}'.format(model_name_tf,
|
||||
bn_weights,
|
||||
ema)
|
||||
stem_top_dict['stem_bn/{}:0'.format(bn_weights)] = tf_name
|
||||
|
||||
# top / head batch normalization
|
||||
for bn_weights in ['beta', 'gamma', 'moving_mean', 'moving_variance']:
|
||||
f_string = '{}/head/tpu_batch_normalization/{}{}'
|
||||
stem_top_dict[f'top_bn/{bn_weights}:0'] = f_string.format(model_name_tf,
|
||||
bn_weights,
|
||||
ema)
|
||||
tf_name = '{}/head/tpu_batch_normalization/{}{}'.format(model_name_tf,
|
||||
bn_weights,
|
||||
ema)
|
||||
stem_top_dict['top_bn/{}:0'.format(bn_weights)] = tf_name
|
||||
|
||||
if keras_name in stem_top_dict:
|
||||
return stem_top_dict[keras_name]
|
||||
raise KeyError(f'{keras_name} from h5 file cannot be parsed')
|
||||
raise KeyError('{} from h5 file cannot be parsed'.format(keras_name))
|
||||
|
||||
|
||||
def keras_name_to_tf_name_block(keras_name,
|
||||
@ -228,8 +235,9 @@ def keras_name_to_tf_name_block(keras_name,
|
||||
ValueError if keras_block does not show up in keras_name
|
||||
"""
|
||||
|
||||
if f'{keras_block}' not in keras_name:
|
||||
raise ValueError(f'block name {keras_block} not found in {keras_name}')
|
||||
if keras_block not in keras_name:
|
||||
raise ValueError('block name {} not found in {}'
|
||||
.format(keras_block, keras_name))
|
||||
|
||||
# all blocks in the first group will not have expand conv and bn
|
||||
is_first_blocks = (keras_block[5] == '1')
|
||||
@ -322,14 +330,14 @@ def check_match(keras_block,
|
||||
names_from_tf.add(x)
|
||||
|
||||
names_missing = names_from_keras - names_from_tf
|
||||
if len(names_missing) > 0:
|
||||
raise ValueError(f'{len(names_missing)} variables not found'
|
||||
f'in checkpoint file: {names_missing}')
|
||||
if names_missing:
|
||||
raise ValueError('{} variables not found in checkpoint file: {}'
|
||||
.format(len(names_missing), names_missing))
|
||||
|
||||
names_unused = names_from_tf - names_from_keras
|
||||
if len(names_unused) > 0:
|
||||
warnings.warn(f'{len(names_unused)} variables from checkpoint file'
|
||||
f'are not used: {names_unused}')
|
||||
if names_unused:
|
||||
warnings.warn('{} variables from checkpoint file are not used: {}'
|
||||
.format(len(names_unused), names_unused))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
x
Reference in New Issue
Block a user