Minor refactoring of conditionals
This commit is contained in:
parent
28195763a0
commit
aa070f5be6
@ -66,11 +66,11 @@ PREPROCESS_INPUT_DOC = """
|
|||||||
{ret}
|
{ret}
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: In case of unknown `data_format` argument.
|
ValueError: In case of unknown `mode` or `data_format` argument.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
PREPROCESS_INPUT_MODE_DOC = """
|
PREPROCESS_INPUT_MODE_DOC = """
|
||||||
mode: One of "caffe", "tf" or "torch".
|
mode: One of "caffe", "tf" or "torch". Defaults to "caffe".
|
||||||
- caffe: will convert the images from RGB to BGR,
|
- caffe: will convert the images from RGB to BGR,
|
||||||
then will zero-center each color channel with
|
then will zero-center each color channel with
|
||||||
respect to the ImageNet dataset,
|
respect to the ImageNet dataset,
|
||||||
@ -97,9 +97,12 @@ PREPROCESS_INPUT_RET_DOC_CAFFE = """
|
|||||||
@keras_export('keras.applications.imagenet_utils.preprocess_input')
|
@keras_export('keras.applications.imagenet_utils.preprocess_input')
|
||||||
def preprocess_input(x, data_format=None, mode='caffe'):
|
def preprocess_input(x, data_format=None, mode='caffe'):
|
||||||
"""Preprocesses a tensor or Numpy array encoding a batch of images."""
|
"""Preprocesses a tensor or Numpy array encoding a batch of images."""
|
||||||
|
if mode not in {'caffe', 'tf','torch'}:
|
||||||
|
raise ValueError('Unknown mode ' + str(mode))
|
||||||
|
|
||||||
if data_format is None:
|
if data_format is None:
|
||||||
data_format = backend.image_data_format()
|
data_format = backend.image_data_format()
|
||||||
if data_format not in {'channels_first', 'channels_last'}:
|
elif data_format not in {'channels_first', 'channels_last'}:
|
||||||
raise ValueError('Unknown data_format ' + str(data_format))
|
raise ValueError('Unknown data_format ' + str(data_format))
|
||||||
|
|
||||||
if isinstance(x, np.ndarray):
|
if isinstance(x, np.ndarray):
|
||||||
@ -182,8 +185,7 @@ def _preprocess_numpy_input(x, data_format, mode):
|
|||||||
x /= 127.5
|
x /= 127.5
|
||||||
x -= 1.
|
x -= 1.
|
||||||
return x
|
return x
|
||||||
|
elif mode == 'torch':
|
||||||
if mode == 'torch':
|
|
||||||
x /= 255.
|
x /= 255.
|
||||||
mean = [0.485, 0.456, 0.406]
|
mean = [0.485, 0.456, 0.406]
|
||||||
std = [0.229, 0.224, 0.225]
|
std = [0.229, 0.224, 0.225]
|
||||||
@ -253,8 +255,7 @@ def _preprocess_symbolic_input(x, data_format, mode):
|
|||||||
x /= 127.5
|
x /= 127.5
|
||||||
x -= 1.
|
x -= 1.
|
||||||
return x
|
return x
|
||||||
|
elif mode == 'torch':
|
||||||
if mode == 'torch':
|
|
||||||
x /= 255.
|
x /= 255.
|
||||||
mean = [0.485, 0.456, 0.406]
|
mean = [0.485, 0.456, 0.406]
|
||||||
std = [0.229, 0.224, 0.225]
|
std = [0.229, 0.224, 0.225]
|
||||||
@ -414,10 +415,10 @@ def validate_activation(classifier_activation, weights):
|
|||||||
return
|
return
|
||||||
|
|
||||||
classifier_activation = activations.get(classifier_activation)
|
classifier_activation = activations.get(classifier_activation)
|
||||||
if classifier_activation not in [
|
if classifier_activation not in {
|
||||||
activations.get('softmax'),
|
activations.get('softmax'),
|
||||||
activations.get(None)
|
activations.get(None)
|
||||||
]:
|
}:
|
||||||
raise ValueError('Only `None` and `softmax` activations are allowed '
|
raise ValueError('Only `None` and `softmax` activations are allowed '
|
||||||
'for the `classifier_activation` argument when using '
|
'for the `classifier_activation` argument when using '
|
||||||
'pretrained weights, with `include_top=True`')
|
'pretrained weights, with `include_top=True`')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user