pylint fix

This commit is contained in:
Yixing Fu 2020-06-26 10:10:42 -04:00
parent 0de0ac1b11
commit 10957aaf35

View File

@ -74,7 +74,7 @@ def write_ckpt_to_h5(path_h5, path_ckpt, keras_model, use_ema=True):
keras_weight_names,
tf_weight_names,
model_name_tf)
print(f'{keras_block} and {tf_block} match.')
print('{} and {} match.'.format(tf_block, keras_block))
block_mapping = {x[0]: x[1] for x in zip(keras_blocks, tf_blocks)}
@ -94,11 +94,12 @@ def write_ckpt_to_h5(path_h5, path_ckpt, keras_model, use_ema=True):
use_ema=use_ema,
model_name_tf=model_name_tf)
elif 'normalization' in w.name:
print(f'skipping variable {w.name}: normalization is a layer'
'in keras implementation, but preprocessing in TF implementation.')
print('skipping variable {}: normalization is a layer'
'in keras implementation, but preprocessing in '
'TF implementation.'.format(w.name))
continue
else:
raise ValueError(f'{w.name} failed to parse.')
raise ValueError('{} failed to parse.'.format(w.name))
try:
w_tf = tf.train.load_variable(path_ckpt, tf_name)
@ -107,11 +108,13 @@ def write_ckpt_to_h5(path_h5, path_ckpt, keras_model, use_ema=True):
changed_weights += 1
except ValueError as e:
if any([x in w.name for x in ['top', 'predictions', 'probs']]):
warnings.warn(f'Fail to load top layer variable {w.name}'
f'from {tf_name} because of {e}.')
warnings.warn('Fail to load top layer variable {}'
'from {} because of {}.'.format(w.name, tf_name, e))
else:
raise ValueError(f'Fail to load {w.name} from {tf_name}')
print(f'{changed_weights}/{len(keras_model.weights)} weights updated')
raise ValueError('Fail to load {} from {}'.format(w.name, tf_name))
total_weights = len(keras_model.weights)
print('{}/{} weights updated'.format(changed_weights, total_weights))
keras_model.save_weights(path_h5)
@ -178,30 +181,34 @@ def keras_name_to_tf_name_stem_top(keras_name,
ema = ''
stem_top_dict = {
'probs/bias:0': f'{model_name_tf}/head/dense/bias{ema}',
'probs/kernel:0': f'{model_name_tf}/head/dense/kernel{ema}',
'predictions/bias:0': f'{model_name_tf}/head/dense/bias{ema}',
'predictions/kernel:0': f'{model_name_tf}/head/dense/kernel{ema}',
'stem_conv/kernel:0': f'{model_name_tf}/stem/conv2d/kernel{ema}',
'top_conv/kernel:0': f'{model_name_tf}/head/conv2d/kernel{ema}',
'probs/bias:0': '{}/head/dense/bias{}',
'probs/kernel:0': '{}/head/dense/kernel{}',
'predictions/bias:0': '{}/head/dense/bias{}',
'predictions/kernel:0': '{}/head/dense/kernel{}',
'stem_conv/kernel:0': '{}/stem/conv2d/kernel{}',
'top_conv/kernel:0': '{}/head/conv2d/kernel{}',
}
for x in stem_top_dict:
stem_top_dict[x] = stem_top_dict[x].format(model_name_tf, ema)
# stem batch normalization
for bn_weights in ['beta', 'gamma', 'moving_mean', 'moving_variance']:
f_string = '{}/stem/tpu_batch_normalization/{}{}'
stem_top_dict[f'stem_bn/{bn_weights}:0'] = f_string.format(model_name_tf,
bn_weights,
ema)
tf_name = '{}/stem/tpu_batch_normalization/{}{}'.format(model_name_tf,
bn_weights,
ema)
stem_top_dict['stem_bn/{}:0'.format(bn_weights)] = tf_name
# top / head batch normalization
for bn_weights in ['beta', 'gamma', 'moving_mean', 'moving_variance']:
f_string = '{}/head/tpu_batch_normalization/{}{}'
stem_top_dict[f'top_bn/{bn_weights}:0'] = f_string.format(model_name_tf,
bn_weights,
ema)
tf_name = '{}/head/tpu_batch_normalization/{}{}'.format(model_name_tf,
bn_weights,
ema)
stem_top_dict['top_bn/{}:0'.format(bn_weights)] = tf_name
if keras_name in stem_top_dict:
return stem_top_dict[keras_name]
raise KeyError(f'{keras_name} from h5 file cannot be parsed')
raise KeyError('{} from h5 file cannot be parsed'.format(keras_name))
def keras_name_to_tf_name_block(keras_name,
@ -228,8 +235,9 @@ def keras_name_to_tf_name_block(keras_name,
ValueError if keras_block does not show up in keras_name
"""
if f'{keras_block}' not in keras_name:
raise ValueError(f'block name {keras_block} not found in {keras_name}')
if keras_block not in keras_name:
raise ValueError('block name {} not found in {}'
.format(keras_block, keras_name))
# all blocks in the first group will not have expand conv and bn
is_first_blocks = (keras_block[5] == '1')
@ -322,14 +330,14 @@ def check_match(keras_block,
names_from_tf.add(x)
names_missing = names_from_keras - names_from_tf
if len(names_missing) > 0:
raise ValueError(f'{len(names_missing)} variables not found'
f'in checkpoint file: {names_missing}')
if names_missing:
raise ValueError('{} variables not found in checkpoint file: {}'
.format(len(names_missing), names_missing))
names_unused = names_from_tf - names_from_keras
if len(names_unused) > 0:
warnings.warn(f'{len(names_unused)} variables from checkpoint file'
f'are not used: {names_unused}')
if names_unused:
warnings.warn('{} variables from checkpoint file are not used: {}'
.format(len(names_unused), names_unused))
if __name__ == '__main__':