Merge pull request #37438 from jaketae:refactor-imageutils

PiperOrigin-RevId: 313705788
Change-Id: Ifb7da82fb5547cc9241981ca37100cd64b64d7f0
TensorFlower Gardener 2020-05-28 19:41:08 -07:00
commit 4853a0bbf4
13 changed files with 58 additions and 22 deletions

View File

@@ -368,7 +368,9 @@ def decode_predictions(preds, top=5):
 preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
-    mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TORCH)
+    mode='',
+    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TORCH,
+    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
 decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__

 DOC = """

View File

@@ -66,11 +66,11 @@ PREPROCESS_INPUT_DOC = """
  {ret}

  Raises:
-    ValueError: In case of unknown `data_format` argument.
+    {error}
  """

 PREPROCESS_INPUT_MODE_DOC = """
-  mode: One of "caffe", "tf" or "torch".
+  mode: One of "caffe", "tf" or "torch". Defaults to "caffe".
    - caffe: will convert the images from RGB to BGR,
      then will zero-center each color channel with
      respect to the ImageNet dataset,
@@ -82,12 +82,18 @@ PREPROCESS_INPUT_MODE_DOC = """
      ImageNet dataset.
  """

+PREPROCESS_INPUT_DEFAULT_ERROR_DOC = """
+  ValueError: In case of unknown `mode` or `data_format` argument."""
+
+PREPROCESS_INPUT_ERROR_DOC = """
+  ValueError: In case of unknown `data_format` argument."""
+
 PREPROCESS_INPUT_RET_DOC_TF = """
  The inputs pixel values are scaled between -1 and 1, sample-wise."""

 PREPROCESS_INPUT_RET_DOC_TORCH = """
  The input pixels values are scaled between 0 and 1 and each channel is
-  normalized with respect to the InageNet dataset."""
+  normalized with respect to the ImageNet dataset."""

 PREPROCESS_INPUT_RET_DOC_CAFFE = """
  The images are converted from RGB to BGR, then each color channel is
@@ -97,9 +103,12 @@ PREPROCESS_INPUT_RET_DOC_CAFFE = """

 @keras_export('keras.applications.imagenet_utils.preprocess_input')
 def preprocess_input(x, data_format=None, mode='caffe'):
   """Preprocesses a tensor or Numpy array encoding a batch of images."""
+  if mode not in {'caffe', 'tf', 'torch'}:
+    raise ValueError('Unknown mode ' + str(mode))
+
   if data_format is None:
     data_format = backend.image_data_format()
-  if data_format not in {'channels_first', 'channels_last'}:
+  elif data_format not in {'channels_first', 'channels_last'}:
     raise ValueError('Unknown data_format ' + str(data_format))

   if isinstance(x, np.ndarray):
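
A minimal usage sketch of the new early mode check (input shape and mode string chosen only for illustration):

import numpy as np
from tensorflow.keras.applications import imagenet_utils

x = np.random.uniform(0, 255, (2, 224, 224, 3))
try:
  imagenet_utils.preprocess_input(x, mode='some_unknown_mode')
except ValueError as e:
  # Raised by the new check, before any data_format handling: "Unknown mode some_unknown_mode"
  print(e)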
@@ -111,7 +120,9 @@ def preprocess_input(x, data_format=None, mode='caffe'):

 preprocess_input.__doc__ = PREPROCESS_INPUT_DOC.format(
-    mode=PREPROCESS_INPUT_MODE_DOC, ret='')
+    mode=PREPROCESS_INPUT_MODE_DOC,
+    ret='',
+    error=PREPROCESS_INPUT_DEFAULT_ERROR_DOC)

 @keras_export('keras.applications.imagenet_utils.decode_predictions')
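
For reference, the docstring templating is plain str.format substitution. A simplified sketch with stand-in strings (the real constants are the much longer fragments defined above):

DOC = """Preprocesses a batch of images.
{mode}
Returns:
{ret}
Raises:
{error}
"""

# The generic entry point documents both failure cases ...
generic_doc = DOC.format(mode='  mode: one of "caffe", "tf", "torch".',
                         ret='  A preprocessed array.',
                         error='  ValueError: unknown `mode` or `data_format`.')
# ... while the per-model wrappers pin the mode and only document `data_format` errors.
model_doc = DOC.format(mode='',
                       ret='  Pixels scaled to [-1, 1].',
                       error='  ValueError: unknown `data_format`.')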
@@ -182,8 +193,7 @@ def _preprocess_numpy_input(x, data_format, mode):
     x /= 127.5
     x -= 1.
     return x
-
-  if mode == 'torch':
+  elif mode == 'torch':
     x /= 255.
     mean = [0.485, 0.456, 0.406]
     std = [0.229, 0.224, 0.225]
@@ -253,8 +263,7 @@ def _preprocess_symbolic_input(x, data_format, mode):
     x /= 127.5
     x -= 1.
     return x
-
-  if mode == 'torch':
+  elif mode == 'torch':
     x /= 255.
     mean = [0.485, 0.456, 0.406]
     std = [0.229, 0.224, 0.225]
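
As a quick numerical reference, this is the arithmetic visible in these two hunks, sketched with a random batch; only the 'tf' and 'torch' constants shown in the diff are used, and channels_last layout is assumed:

import numpy as np

x = np.random.uniform(0, 255, (2, 224, 224, 3)).astype('float32')

# mode='tf': scale pixels to [-1, 1], sample-wise
x_tf = x / 127.5 - 1.

# mode='torch': scale to [0, 1], then normalize each channel
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
x_torch = (x / 255. - mean) / std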
@@ -414,10 +423,10 @@ def validate_activation(classifier_activation, weights):
     return

   classifier_activation = activations.get(classifier_activation)
-  if classifier_activation not in [
+  if classifier_activation not in {
       activations.get('softmax'),
       activations.get(None)
-  ]:
+  }:
     raise ValueError('Only `None` and `softmax` activations are allowed '
                      'for the `classifier_activation` argument when using '
                      'pretrained weights, with `include_top=True`')
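
The list-to-set change only swaps the container used for the membership test. A standalone sketch that mirrors the check via the public tf.keras.activations.get (the weights handling here is an assumption for illustration, not the library function itself):

from tensorflow.keras import activations

def mirror_validate_activation(classifier_activation, weights='imagenet'):
  # Mirrors the membership test above: resolve the identifier, then require
  # it to be the softmax function or the default returned for None.
  if weights is None:
    return
  resolved = activations.get(classifier_activation)
  if resolved not in {activations.get('softmax'), activations.get(None)}:
    raise ValueError('Only `None` and `softmax` activations are allowed.')

mirror_validate_activation('softmax')    # passes
mirror_validate_activation(None)         # passes
# mirror_validate_activation('sigmoid')  # would raise ValueError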

View File

@@ -29,6 +29,11 @@ from tensorflow.python.platform import test

 class TestImageNetUtils(keras_parameterized.TestCase):

   def test_preprocess_input(self):
+    # Test invalid mode check
+    x = np.random.uniform(0, 255, (10, 10, 3))
+    with self.assertRaises(ValueError):
+      utils.preprocess_input(x, mode='some_unknown_mode')
+    # Test image batch with float and int image input
     x = np.random.uniform(0, 255, (2, 10, 10, 3))
     xint = x.astype('int32')
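
Not part of this commit, but the valid-mode path that the new test leaves implicit can be checked directly; a standalone sketch, with imagenet_utils imported under the same `utils` alias the test module uses:

import numpy as np
from tensorflow.keras.applications import imagenet_utils as utils

x = np.random.uniform(0, 255, (2, 10, 10, 3))
for mode in ('caffe', 'tf', 'torch'):
  out = utils.preprocess_input(x.copy(), mode=mode)
  assert out.shape == x.shape  # preprocessing preserves the batch shape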

View File

@@ -389,5 +389,7 @@ def decode_predictions(preds, top=5):
 preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
-    mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF)
+    mode='',
+    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
+    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
 decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
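
The remaining application modules receive the same mechanical change, so each module-level preprocess_input docstring now carries the data_format-only Raises section. A quick way to inspect the rendered result (module name chosen only for illustration, since the diff does not show file names):

from tensorflow.keras.applications import inception_v3

# The formatted docstring now ends with a Raises: section describing the
# ValueError raised for an unknown `data_format` argument.
print(inception_v3.preprocess_input.__doc__)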

View File

@@ -415,5 +415,7 @@ def decode_predictions(preds, top=5):
 preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
-    mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF)
+    mode='',
+    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
+    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
 decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__

View File

@@ -451,5 +451,7 @@ def decode_predictions(preds, top=5):
 preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
-    mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF)
+    mode='',
+    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
+    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
 decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__

View File

@@ -508,5 +508,7 @@ def decode_predictions(preds, top=5):
 preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
-    mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF)
+    mode='',
+    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
+    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
 decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__

View File

@@ -794,5 +794,7 @@ def decode_predictions(preds, top=5):
 preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
-    mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF)
+    mode='',
+    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
+    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
 decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__

View File

@@ -531,7 +531,9 @@ def decode_predictions(preds, top=5):
 preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
-    mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE)
+    mode='',
+    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE,
+    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
 decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__

 DOC = """

View File

@@ -133,7 +133,9 @@ def decode_predictions(preds, top=5):
 preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
-    mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE)
+    mode='',
+    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
+    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
 decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__

 DOC = """

View File

@@ -237,5 +237,7 @@ def decode_predictions(preds, top=5):
 preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
-    mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE)
+    mode='',
+    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE,
+    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
 decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__

View File

@@ -242,5 +242,7 @@ def decode_predictions(preds, top=5):
 preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
-    mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE)
+    mode='',
+    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE,
+    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
 decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__

View File

@@ -325,5 +325,7 @@ def decode_predictions(preds, top=5):
 preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
-    mode='', ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF)
+    mode='',
+    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
+    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
 decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__