Merge pull request #37438 from jaketae:refactor-imageutils
PiperOrigin-RevId: 313705788 Change-Id: Ifb7da82fb5547cc9241981ca37100cd64b64d7f0
This commit is contained in:
commit
4853a0bbf4
@ -368,7 +368,9 @@ def decode_predictions(preds, top=5):
|
||||
|
||||
|
||||
preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
|
||||
mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TORCH)
|
||||
mode='',
|
||||
ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TORCH,
|
||||
error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
|
||||
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
|
||||
|
||||
DOC = """
|
||||
|
@ -66,11 +66,11 @@ PREPROCESS_INPUT_DOC = """
|
||||
{ret}
|
||||
|
||||
Raises:
|
||||
ValueError: In case of unknown `data_format` argument.
|
||||
{error}
|
||||
"""
|
||||
|
||||
PREPROCESS_INPUT_MODE_DOC = """
|
||||
mode: One of "caffe", "tf" or "torch".
|
||||
mode: One of "caffe", "tf" or "torch". Defaults to "caffe".
|
||||
- caffe: will convert the images from RGB to BGR,
|
||||
then will zero-center each color channel with
|
||||
respect to the ImageNet dataset,
|
||||
@ -82,12 +82,18 @@ PREPROCESS_INPUT_MODE_DOC = """
|
||||
ImageNet dataset.
|
||||
"""
|
||||
|
||||
PREPROCESS_INPUT_DEFAULT_ERROR_DOC = """
|
||||
ValueError: In case of unknown `mode` or `data_format` argument."""
|
||||
|
||||
PREPROCESS_INPUT_ERROR_DOC = """
|
||||
ValueError: In case of unknown `data_format` argument."""
|
||||
|
||||
PREPROCESS_INPUT_RET_DOC_TF = """
|
||||
The inputs pixel values are scaled between -1 and 1, sample-wise."""
|
||||
|
||||
PREPROCESS_INPUT_RET_DOC_TORCH = """
|
||||
The input pixels values are scaled between 0 and 1 and each channel is
|
||||
normalized with respect to the InageNet dataset."""
|
||||
normalized with respect to the ImageNet dataset."""
|
||||
|
||||
PREPROCESS_INPUT_RET_DOC_CAFFE = """
|
||||
The images are converted from RGB to BGR, then each color channel is
|
||||
@ -97,9 +103,12 @@ PREPROCESS_INPUT_RET_DOC_CAFFE = """
|
||||
@keras_export('keras.applications.imagenet_utils.preprocess_input')
|
||||
def preprocess_input(x, data_format=None, mode='caffe'):
|
||||
"""Preprocesses a tensor or Numpy array encoding a batch of images."""
|
||||
if mode not in {'caffe', 'tf', 'torch'}:
|
||||
raise ValueError('Unknown mode ' + str(mode))
|
||||
|
||||
if data_format is None:
|
||||
data_format = backend.image_data_format()
|
||||
if data_format not in {'channels_first', 'channels_last'}:
|
||||
elif data_format not in {'channels_first', 'channels_last'}:
|
||||
raise ValueError('Unknown data_format ' + str(data_format))
|
||||
|
||||
if isinstance(x, np.ndarray):
|
||||
@ -111,7 +120,9 @@ def preprocess_input(x, data_format=None, mode='caffe'):
|
||||
|
||||
|
||||
preprocess_input.__doc__ = PREPROCESS_INPUT_DOC.format(
|
||||
mode=PREPROCESS_INPUT_MODE_DOC, ret='')
|
||||
mode=PREPROCESS_INPUT_MODE_DOC,
|
||||
ret='',
|
||||
error=PREPROCESS_INPUT_DEFAULT_ERROR_DOC)
|
||||
|
||||
|
||||
@keras_export('keras.applications.imagenet_utils.decode_predictions')
|
||||
@ -182,8 +193,7 @@ def _preprocess_numpy_input(x, data_format, mode):
|
||||
x /= 127.5
|
||||
x -= 1.
|
||||
return x
|
||||
|
||||
if mode == 'torch':
|
||||
elif mode == 'torch':
|
||||
x /= 255.
|
||||
mean = [0.485, 0.456, 0.406]
|
||||
std = [0.229, 0.224, 0.225]
|
||||
@ -253,8 +263,7 @@ def _preprocess_symbolic_input(x, data_format, mode):
|
||||
x /= 127.5
|
||||
x -= 1.
|
||||
return x
|
||||
|
||||
if mode == 'torch':
|
||||
elif mode == 'torch':
|
||||
x /= 255.
|
||||
mean = [0.485, 0.456, 0.406]
|
||||
std = [0.229, 0.224, 0.225]
|
||||
@ -414,10 +423,10 @@ def validate_activation(classifier_activation, weights):
|
||||
return
|
||||
|
||||
classifier_activation = activations.get(classifier_activation)
|
||||
if classifier_activation not in [
|
||||
if classifier_activation not in {
|
||||
activations.get('softmax'),
|
||||
activations.get(None)
|
||||
]:
|
||||
}:
|
||||
raise ValueError('Only `None` and `softmax` activations are allowed '
|
||||
'for the `classifier_activation` argument when using '
|
||||
'pretrained weights, with `include_top=True`')
|
||||
|
@ -29,6 +29,11 @@ from tensorflow.python.platform import test
|
||||
class TestImageNetUtils(keras_parameterized.TestCase):
|
||||
|
||||
def test_preprocess_input(self):
|
||||
# Test invalid mode check
|
||||
x = np.random.uniform(0, 255, (10, 10, 3))
|
||||
with self.assertRaises(ValueError):
|
||||
utils.preprocess_input(x, mode='some_unknown_mode')
|
||||
|
||||
# Test image batch with float and int image input
|
||||
x = np.random.uniform(0, 255, (2, 10, 10, 3))
|
||||
xint = x.astype('int32')
|
||||
|
@ -389,5 +389,7 @@ def decode_predictions(preds, top=5):
|
||||
|
||||
|
||||
preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
|
||||
mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF)
|
||||
mode='',
|
||||
ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
|
||||
error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
|
||||
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
|
||||
|
@ -415,5 +415,7 @@ def decode_predictions(preds, top=5):
|
||||
|
||||
|
||||
preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
|
||||
mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF)
|
||||
mode='',
|
||||
ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
|
||||
error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
|
||||
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
|
||||
|
@ -451,5 +451,7 @@ def decode_predictions(preds, top=5):
|
||||
|
||||
|
||||
preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
|
||||
mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF)
|
||||
mode='',
|
||||
ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
|
||||
error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
|
||||
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
|
||||
|
@ -508,5 +508,7 @@ def decode_predictions(preds, top=5):
|
||||
|
||||
|
||||
preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
|
||||
mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF)
|
||||
mode='',
|
||||
ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
|
||||
error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
|
||||
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
|
||||
|
@ -794,5 +794,7 @@ def decode_predictions(preds, top=5):
|
||||
|
||||
|
||||
preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
|
||||
mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF)
|
||||
mode='',
|
||||
ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
|
||||
error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
|
||||
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
|
||||
|
@ -531,7 +531,9 @@ def decode_predictions(preds, top=5):
|
||||
|
||||
|
||||
preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
|
||||
mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE)
|
||||
mode='',
|
||||
ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE,
|
||||
error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
|
||||
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
|
||||
|
||||
DOC = """
|
||||
|
@ -133,7 +133,9 @@ def decode_predictions(preds, top=5):
|
||||
|
||||
|
||||
preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
|
||||
mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE)
|
||||
mode='',
|
||||
ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
|
||||
error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
|
||||
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
|
||||
|
||||
DOC = """
|
||||
|
@ -237,5 +237,7 @@ def decode_predictions(preds, top=5):
|
||||
|
||||
|
||||
preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
|
||||
mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE)
|
||||
mode='',
|
||||
ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE,
|
||||
error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
|
||||
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
|
||||
|
@ -242,5 +242,7 @@ def decode_predictions(preds, top=5):
|
||||
|
||||
|
||||
preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
|
||||
mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE)
|
||||
mode='',
|
||||
ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE,
|
||||
error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
|
||||
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
|
||||
|
@ -325,5 +325,7 @@ def decode_predictions(preds, top=5):
|
||||
|
||||
|
||||
preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
|
||||
mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF)
|
||||
mode='',
|
||||
ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
|
||||
error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
|
||||
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
|
||||
|
Loading…
Reference in New Issue
Block a user