Port the Mobilenet_v3 to keras/application.

Fix https://github.com/tensorflow/tensorflow/issues/40217.

The implementation is based on https://github.com/keras-team/keras-applications/blob/master/keras_applications/mobilenet_v3.py with a few modifications.

1. Updated to use TF backend only (theano related code is removed).
2. Remove all the *kwargs, and directly use tf.keras packages (disallow package injection).
3. Add 'classifier_activation' which is used by 'top' layer. This is aligned with v1/v2 implementation.
4. [Major] Changed the include_top implementation. The Conv2D layer with name "Conv_2" and its activation is moved to be base model structure, which means they are in the model even the include_top is False. This is based on comparing the implementation detail in original slim implementation in a811a3b7e6/research/slim/nets/mobilenet/mobilenet_v3.py. If we can confirm this change is correct, then we should also fix it on the OSS keras_application as well.
5. [Major] Remove the first ZeroPadding2D layer right after the model input, and change the first conv2D layer to use "same" padding. This is aligned with original implementation in 692215511a/research/slim/nets/mobilenet/mobilenet.py (L155), where use_explicit_padding is False.
6. Added API for preprocess_input and decode_predictions, which aligns with v1 and v2 implementation.

PiperOrigin-RevId: 325734579
Change-Id: I2ba6a9aa695baaa145d1a7cd3aeae86d48b823a2
This commit is contained in:
Scott Zhu 2020-08-09 20:54:59 -07:00 committed by TensorFlower Gardener
parent 0720cbfdd3
commit ef88a7aad4
12 changed files with 640 additions and 0 deletions

View File

@ -101,6 +101,7 @@
* `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
as an alternative to accepting a `callable` loss.
* Added `beta` parameter to FTRL optimizer to match paper.
* Added `mobilenet_v3` to keras application model.
* `tf.function` / AutoGraph:
* Added `experimental_follow_type_hints` argument for `tf.function`. When
True, the function may use type annotations to optimize the tracing

View File

@ -23,6 +23,7 @@ keras_packages = [
"tensorflow.python.keras.applications.inception_v3",
"tensorflow.python.keras.applications.mobilenet",
"tensorflow.python.keras.applications.mobilenet_v2",
"tensorflow.python.keras.applications.mobilenet_v3",
"tensorflow.python.keras.applications.nasnet",
"tensorflow.python.keras.applications.resnet",
"tensorflow.python.keras.applications.resnet_v2",

View File

@ -25,6 +25,7 @@ py_library(
"inception_v3.py",
"mobilenet.py",
"mobilenet_v2.py",
"mobilenet_v3.py",
"nasnet.py",
"resnet.py",
"resnet_v2.py",
@ -209,6 +210,22 @@ tf_py_test(
],
)
tf_py_test(
name = "applications_load_weight_test_mobilenet_v3",
srcs = ["applications_load_weight_test.py"],
args = ["--module=mobilenet_v3"],
main = "applications_load_weight_test.py",
tags = [
"no_oss",
"no_pip",
],
deps = [
":applications",
"//tensorflow/python:client_testlib",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "applications_load_weight_test_densenet",
size = "large",

View File

@ -28,6 +28,7 @@ from tensorflow.python.keras.applications import inception_resnet_v2
from tensorflow.python.keras.applications import inception_v3
from tensorflow.python.keras.applications import mobilenet
from tensorflow.python.keras.applications import mobilenet_v2
from tensorflow.python.keras.applications import mobilenet_v3
from tensorflow.python.keras.applications import nasnet
from tensorflow.python.keras.applications import resnet
from tensorflow.python.keras.applications import resnet_v2
@ -51,6 +52,8 @@ ARG_TO_MODEL = {
[inception_resnet_v2.InceptionResNetV2]),
'mobilenet': (mobilenet, [mobilenet.MobileNet]),
'mobilenet_v2': (mobilenet_v2, [mobilenet_v2.MobileNetV2]),
'mobilenet_v3': (mobilenet_v3, [mobilenet_v3.MobileNetV3Small,
mobilenet_v3.MobileNetV3Large]),
'densenet': (densenet, [densenet.DenseNet121,
densenet.DenseNet169, densenet.DenseNet201]),
'nasnet_mobile': (nasnet, [nasnet.NASNetMobile]),

View File

@ -27,6 +27,7 @@ from tensorflow.python.keras.applications import inception_resnet_v2
from tensorflow.python.keras.applications import inception_v3
from tensorflow.python.keras.applications import mobilenet
from tensorflow.python.keras.applications import mobilenet_v2
from tensorflow.python.keras.applications import mobilenet_v3
from tensorflow.python.keras.applications import nasnet
from tensorflow.python.keras.applications import resnet
from tensorflow.python.keras.applications import resnet_v2
@ -50,6 +51,8 @@ MODEL_LIST_NO_NASNET = [
(inception_resnet_v2.InceptionResNetV2, 1536),
(mobilenet.MobileNet, 1024),
(mobilenet_v2.MobileNetV2, 1280),
(mobilenet_v3.MobileNetV3Small, 1024),
(mobilenet_v3.MobileNetV3Large, 1280),
(densenet.DenseNet121, 1024),
(densenet.DenseNet169, 1664),
(densenet.DenseNet201, 1920),

View File

@ -0,0 +1,567 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
# pylint: disable=missing-function-docstring
"""MobileNet v3 models for Keras."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.keras import backend
from tensorflow.python.keras import models
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export
# TODO(scottzhu): Change this to the GCS path.
BASE_WEIGHT_PATH = ('https://storage.googleapis.com/tensorflow/'
'keras-applications/mobilenet_v3/')
WEIGHTS_HASHES = {
'large_224_0.75_float': ('765b44a33ad4005b3ac83185abf1d0eb',
'e7b4d1071996dd51a2c2ca2424570e20'),
'large_224_1.0_float': ('59e551e166be033d707958cf9e29a6a7',
'037116398e07f018c0005ffcb0406831'),
'large_minimalistic_224_1.0_float': ('675e7b876c45c57e9e63e6d90a36599c',
'a2c33aed672524d1d0b4431808177695'),
'small_224_0.75_float': ('cb65d4e5be93758266aa0a7f2c6708b7',
'4d2fe46f1c1f38057392514b0df1d673'),
'small_224_1.0_float': ('8768d4c2e7dee89b9d02b2d03d65d862',
'be7100780f875c06bcab93d76641aa26'),
'small_minimalistic_224_1.0_float': ('99cd97fb2fcdad2bf028eb838de69e37',
'20d4e357df3f7a6361f3a288857b1051'),
}
layers = VersionAwareLayers()
BASE_DOCSTRING = """Instantiates the {name} architecture.
Reference:
- [Searching for MobileNetV3](
https://arxiv.org/pdf/1905.02244.pdf) (ICCV 2019)
The following table describes the performance of MobileNets:
------------------------------------------------------------------------
MACs stands for Multiply Adds
|Classification Checkpoint|MACs(M)|Parameters(M)|Top1 Accuracy|Pixel1|CPU(ms)|
| [mobilenet_v3_large_1.0_224] | 217 | 5.4 | 75.6 | 51.2 |
| [mobilenet_v3_large_0.75_224] | 155 | 4.0 | 73.3 | 39.8 |
| [mobilenet_v3_large_minimalistic_1.0_224] | 209 | 3.9 | 72.3 | 44.1 |
| [mobilenet_v3_small_1.0_224] | 66 | 2.9 | 68.1 | 15.8 |
| [mobilenet_v3_small_0.75_224] | 44 | 2.4 | 65.4 | 12.8 |
| [mobilenet_v3_small_minimalistic_1.0_224] | 65 | 2.0 | 61.9 | 12.2 |
The weights for all 6 models are obtained and translated from the Tensorflow
checkpoints from TensorFlow checkpoints found [here]
(https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet/README.md).
Optionally loads weights pre-trained on ImageNet.
Caution: Be sure to properly pre-process your inputs to the application.
Please see `applications.mobilenet_v3.preprocess_input` for an example.
Arguments:
input_shape: Optional shape tuple, to be specified if you would
like to use a model with an input image resolution that is not
(224, 224, 3).
It should have exactly 3 inputs channels (224, 224, 3).
You can also omit this option if you would like
to infer input_shape from an input_tensor.
If you choose to include both input_tensor and input_shape then
input_shape will be used if they match, if the shapes
do not match then we will throw an error.
E.g. `(160, 160, 3)` would be one valid value.
alpha: controls the width of the network. This is known as the
depth multiplier in the MobileNetV3 paper, but the name is kept for
consistency with MobileNetV1 in Keras.
- If `alpha` < 1.0, proportionally decreases the number
of filters in each layer.
- If `alpha` > 1.0, proportionally increases the number
of filters in each layer.
- If `alpha` = 1, default number of filters from the paper
are used at each layer.
minimalistic: In addition to large and small models this module also
contains so-called minimalistic models, these models have the same
per-layer dimensions characteristic as MobilenetV3 however, they don't
utilize any of the advanced blocks (squeeze-and-excite units, hard-swish,
and 5x5 convolutions). While these models are less efficient on CPU, they
are much more performant on GPU/DSP.
include_top: Boolean, whether to include the fully-connected
layer at the top of the network. Defaults to `True`.
weights: String, one of `None` (random initialization),
'imagenet' (pre-training on ImageNet),
or the path to the weights file to be loaded.
input_tensor: Optional Keras tensor (i.e. output of
`layers.Input()`)
to use as image input for the model.
pooling: String, optional pooling mode for feature extraction
when `include_top` is `False`.
- `None` means that the output of the model
will be the 4D tensor output of the
last convolutional block.
- `avg` means that global average pooling
will be applied to the output of the
last convolutional block, and thus
the output of the model will be a
2D tensor.
- `max` means that global max pooling will
be applied.
classes: Integer, optional number of classes to classify images
into, only to be specified if `include_top` is True, and
if no `weights` argument is specified.
dropout_rate: fraction of the input units to drop on the last layer.
classifier_activation: A `str` or callable. The activation function to use
on the "top" layer. Ignored unless `include_top=True`. Set
`classifier_activation=None` to return the logits of the "top" layer.
Returns:
A `keras.Model` instance.
Raises:
ValueError: in case of invalid argument for `weights`,
or invalid input shape or invalid alpha, rows when
weights='imagenet'
ValueError: if `classifier_activation` is not `softmax` or `None` when
using a pretrained top layer.
"""
def MobileNetV3(stack_fn,
last_point_ch,
input_shape=None,
alpha=1.0,
model_type='large',
minimalistic=False,
include_top=True,
weights='imagenet',
input_tensor=None,
classes=1000,
pooling=None,
dropout_rate=0.2,
classifier_activation='softmax'):
if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
raise ValueError('The `weights` argument should be either '
'`None` (random initialization), `imagenet` '
'(pre-training on ImageNet), '
'or the path to the weights file to be loaded.')
if weights == 'imagenet' and include_top and classes != 1000:
raise ValueError('If using `weights` as `"imagenet"` with `include_top` '
'as true, `classes` should be 1000')
# Determine proper input shape and default size.
# If both input_shape and input_tensor are used, they should match
if input_shape is not None and input_tensor is not None:
try:
is_input_t_tensor = backend.is_keras_tensor(input_tensor)
except ValueError:
try:
is_input_t_tensor = backend.is_keras_tensor(
layer_utils.get_source_inputs(input_tensor))
except ValueError:
raise ValueError('input_tensor: ', input_tensor,
'is not type input_tensor')
if is_input_t_tensor:
if backend.image_data_format == 'channels_first':
if backend.int_shape(input_tensor)[1] != input_shape[1]:
raise ValueError('input_shape: ', input_shape, 'and input_tensor: ',
input_tensor,
'do not meet the same shape requirements')
else:
if backend.int_shape(input_tensor)[2] != input_shape[1]:
raise ValueError('input_shape: ', input_shape, 'and input_tensor: ',
input_tensor,
'do not meet the same shape requirements')
else:
raise ValueError('input_tensor specified: ', input_tensor,
'is not a keras tensor')
# If input_shape is None, infer shape from input_tensor
if input_shape is None and input_tensor is not None:
try:
backend.is_keras_tensor(input_tensor)
except ValueError:
raise ValueError('input_tensor: ', input_tensor, 'is type: ',
type(input_tensor), 'which is not a valid type')
if backend.is_keras_tensor(input_tensor):
if backend.image_data_format() == 'channels_first':
rows = backend.int_shape(input_tensor)[2]
cols = backend.int_shape(input_tensor)[3]
input_shape = (3, cols, rows)
else:
rows = backend.int_shape(input_tensor)[1]
cols = backend.int_shape(input_tensor)[2]
input_shape = (cols, rows, 3)
# If input_shape is None and input_tensor is None using standart shape
if input_shape is None and input_tensor is None:
input_shape = (None, None, 3)
if backend.image_data_format() == 'channels_last':
row_axis, col_axis = (0, 1)
else:
row_axis, col_axis = (1, 2)
rows = input_shape[row_axis]
cols = input_shape[col_axis]
if rows and cols and (rows < 32 or cols < 32):
raise ValueError('Input size must be at least 32x32; got `input_shape=' +
str(input_shape) + '`')
if weights == 'imagenet':
if (not minimalistic and alpha not in [0.75, 1.0]
or minimalistic and alpha != 1.0):
raise ValueError('If imagenet weights are being loaded, '
'alpha can be one of `0.75`, `1.0` for non minimalistic'
' or `1.0` for minimalistic only.')
if rows != cols or rows != 224:
logging.warning('`input_shape` is undefined or non-square, '
'or `rows` is not 224.'
' Weights for input shape (224, 224) will be'
' loaded as the default.')
if input_tensor is None:
img_input = layers.Input(shape=input_shape)
else:
if not backend.is_keras_tensor(input_tensor):
img_input = layers.Input(tensor=input_tensor, shape=input_shape)
else:
img_input = input_tensor
channel_axis = 1 if backend.image_data_format() == 'channels_first' else -1
if minimalistic:
kernel = 3
activation = relu
se_ratio = None
else:
kernel = 5
activation = hard_swish
se_ratio = 0.25
x = img_input
x = layers.Rescaling(1. / 255.)(x)
x = layers.Conv2D(
16,
kernel_size=3,
strides=(2, 2),
padding='same',
use_bias=False,
name='Conv')(x)
x = layers.BatchNormalization(
axis=channel_axis, epsilon=1e-3,
momentum=0.999, name='Conv/BatchNorm')(x)
x = activation(x)
x = stack_fn(x, kernel, activation, se_ratio)
last_conv_ch = _depth(backend.int_shape(x)[channel_axis] * 6)
# if the width multiplier is greater than 1 we
# increase the number of output channels
if alpha > 1.0:
last_point_ch = _depth(last_point_ch * alpha)
x = layers.Conv2D(
last_conv_ch,
kernel_size=1,
padding='same',
use_bias=False,
name='Conv_1')(x)
x = layers.BatchNormalization(
axis=channel_axis, epsilon=1e-3,
momentum=0.999, name='Conv_1/BatchNorm')(x)
x = activation(x)
x = layers.Conv2D(
last_point_ch,
kernel_size=1,
padding='same',
use_bias=True,
name='Conv_2')(x)
x = activation(x)
if include_top:
x = layers.GlobalAveragePooling2D()(x)
if channel_axis == 1:
x = layers.Reshape((last_point_ch, 1, 1))(x)
else:
x = layers.Reshape((1, 1, last_point_ch))(x)
if dropout_rate > 0:
x = layers.Dropout(dropout_rate)(x)
x = layers.Conv2D(classes, kernel_size=1, padding='same', name='Logits')(x)
x = layers.Flatten()(x)
imagenet_utils.validate_activation(classifier_activation, weights)
x = layers.Activation(activation=classifier_activation,
name='Predictions')(x)
else:
if pooling == 'avg':
x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
elif pooling == 'max':
x = layers.GlobalMaxPooling2D(name='max_pool')(x)
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
inputs = layer_utils.get_source_inputs(input_tensor)
else:
inputs = img_input
# Create model.
model = models.Model(inputs, x, name='MobilenetV3' + model_type)
# Load weights.
if weights == 'imagenet':
model_name = '{}{}_224_{}_float'.format(
model_type, '_minimalistic' if minimalistic else '', str(alpha))
if include_top:
file_name = 'weights_mobilenet_v3_' + model_name + '.h5'
file_hash = WEIGHTS_HASHES[model_name][0]
else:
file_name = 'weights_mobilenet_v3_' + model_name + '_no_top.h5'
file_hash = WEIGHTS_HASHES[model_name][1]
weights_path = data_utils.get_file(
file_name,
BASE_WEIGHT_PATH + file_name,
cache_subdir='models',
file_hash=file_hash)
model.load_weights(weights_path)
elif weights is not None:
model.load_weights(weights)
return model
@keras_export('keras.applications.MobileNetV3Samll')
def MobileNetV3Small(input_shape=None,
alpha=1.0,
minimalistic=False,
include_top=True,
weights='imagenet',
input_tensor=None,
classes=1000,
pooling=None,
dropout_rate=0.2,
classifier_activation='softmax'):
def stack_fn(x, kernel, activation, se_ratio):
def depth(d):
return _depth(d * alpha)
x = _inverted_res_block(x, 1, depth(16), 3, 2, se_ratio, relu, 0)
x = _inverted_res_block(x, 72. / 16, depth(24), 3, 2, None, relu, 1)
x = _inverted_res_block(x, 88. / 24, depth(24), 3, 1, None, relu, 2)
x = _inverted_res_block(x, 4, depth(40), kernel, 2, se_ratio, activation, 3)
x = _inverted_res_block(x, 6, depth(40), kernel, 1, se_ratio, activation, 4)
x = _inverted_res_block(x, 6, depth(40), kernel, 1, se_ratio, activation, 5)
x = _inverted_res_block(x, 3, depth(48), kernel, 1, se_ratio, activation, 6)
x = _inverted_res_block(x, 3, depth(48), kernel, 1, se_ratio, activation, 7)
x = _inverted_res_block(x, 6, depth(96), kernel, 2, se_ratio, activation, 8)
x = _inverted_res_block(x, 6, depth(96), kernel, 1, se_ratio, activation, 9)
x = _inverted_res_block(x, 6, depth(96), kernel, 1, se_ratio, activation,
10)
return x
return MobileNetV3(stack_fn, 1024, input_shape, alpha, 'small', minimalistic,
include_top, weights, input_tensor, classes, pooling,
dropout_rate, classifier_activation)
@keras_export('keras.applications.MobileNetV3Large')
def MobileNetV3Large(input_shape=None,
alpha=1.0,
minimalistic=False,
include_top=True,
weights='imagenet',
input_tensor=None,
classes=1000,
pooling=None,
dropout_rate=0.2,
classifier_activation='softmax'):
def stack_fn(x, kernel, activation, se_ratio):
def depth(d):
return _depth(d * alpha)
x = _inverted_res_block(x, 1, depth(16), 3, 1, None, relu, 0)
x = _inverted_res_block(x, 4, depth(24), 3, 2, None, relu, 1)
x = _inverted_res_block(x, 3, depth(24), 3, 1, None, relu, 2)
x = _inverted_res_block(x, 3, depth(40), kernel, 2, se_ratio, relu, 3)
x = _inverted_res_block(x, 3, depth(40), kernel, 1, se_ratio, relu, 4)
x = _inverted_res_block(x, 3, depth(40), kernel, 1, se_ratio, relu, 5)
x = _inverted_res_block(x, 6, depth(80), 3, 2, None, activation, 6)
x = _inverted_res_block(x, 2.5, depth(80), 3, 1, None, activation, 7)
x = _inverted_res_block(x, 2.3, depth(80), 3, 1, None, activation, 8)
x = _inverted_res_block(x, 2.3, depth(80), 3, 1, None, activation, 9)
x = _inverted_res_block(x, 6, depth(112), 3, 1, se_ratio, activation, 10)
x = _inverted_res_block(x, 6, depth(112), 3, 1, se_ratio, activation, 11)
x = _inverted_res_block(x, 6, depth(160), kernel, 2, se_ratio, activation,
12)
x = _inverted_res_block(x, 6, depth(160), kernel, 1, se_ratio, activation,
13)
x = _inverted_res_block(x, 6, depth(160), kernel, 1, se_ratio, activation,
14)
return x
return MobileNetV3(stack_fn, 1280, input_shape, alpha, 'large', minimalistic,
include_top, weights, input_tensor, classes, pooling,
dropout_rate, classifier_activation)
MobileNetV3Small.__doc__ = BASE_DOCSTRING.format(name='MobileNetV3Small')
MobileNetV3Large.__doc__ = BASE_DOCSTRING.format(name='MobileNetV3Large')
def relu(x):
return layers.ReLU()(x)
def hard_sigmoid(x):
return layers.ReLU(6.)(x + 3.) * (1. / 6.)
def hard_swish(x):
return layers.Multiply()([hard_sigmoid(x), x])
# This function is taken from the original tf repo.
# It ensures that all layers have a channel number that is divisible by 8
# It can be seen here:
# https://github.com/tensorflow/models/blob/master/research/
# slim/nets/mobilenet/mobilenet.py
def _depth(v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
def _se_block(inputs, filters, se_ratio, prefix):
x = layers.GlobalAveragePooling2D(name=prefix + 'squeeze_excite/AvgPool')(
inputs)
if backend.image_data_format() == 'channels_first':
x = layers.Reshape((filters, 1, 1))(x)
else:
x = layers.Reshape((1, 1, filters))(x)
x = layers.Conv2D(
_depth(filters * se_ratio),
kernel_size=1,
padding='same',
name=prefix + 'squeeze_excite/Conv')(
x)
x = layers.ReLU(name=prefix + 'squeeze_excite/Relu')(x)
x = layers.Conv2D(
filters,
kernel_size=1,
padding='same',
name=prefix + 'squeeze_excite/Conv_1')(
x)
x = hard_sigmoid(x)
x = layers.Multiply(name=prefix + 'squeeze_excite/Mul')([inputs, x])
return x
def _inverted_res_block(x, expansion, filters, kernel_size, stride, se_ratio,
activation, block_id):
channel_axis = 1 if backend.image_data_format() == 'channels_first' else -1
shortcut = x
prefix = 'expanded_conv/'
infilters = backend.int_shape(x)[channel_axis]
if block_id:
# Expand
prefix = 'expanded_conv_{}/'.format(block_id)
x = layers.Conv2D(
_depth(infilters * expansion),
kernel_size=1,
padding='same',
use_bias=False,
name=prefix + 'expand')(
x)
x = layers.BatchNormalization(
axis=channel_axis,
epsilon=1e-3,
momentum=0.999,
name=prefix + 'expand/BatchNorm')(
x)
x = activation(x)
if stride == 2:
x = layers.ZeroPadding2D(
padding=imagenet_utils.correct_pad(x, kernel_size),
name=prefix + 'depthwise/pad')(
x)
x = layers.DepthwiseConv2D(
kernel_size,
strides=stride,
padding='same' if stride == 1 else 'valid',
use_bias=False,
name=prefix + 'depthwise')(
x)
x = layers.BatchNormalization(
axis=channel_axis,
epsilon=1e-3,
momentum=0.999,
name=prefix + 'depthwise/BatchNorm')(
x)
x = activation(x)
if se_ratio:
x = _se_block(x, _depth(infilters * expansion), se_ratio, prefix)
x = layers.Conv2D(
filters,
kernel_size=1,
padding='same',
use_bias=False,
name=prefix + 'project')(
x)
x = layers.BatchNormalization(
axis=channel_axis,
epsilon=1e-3,
momentum=0.999,
name=prefix + 'project/BatchNorm')(
x)
if stride == 1 and infilters == filters:
x = layers.Add(name=prefix + 'Add')([shortcut, x])
return x
@keras_export('keras.applications.mobilenet_v3.preprocess_input')
def preprocess_input(x, data_format=None): # pylint: disable=unused-argument
return x
@keras_export('keras.applications.mobilenet_v3.decode_predictions')
def decode_predictions(preds, top=5):
return imagenet_utils.decode_predictions(preds, top=top)
preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
mode='',
ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__

View File

@ -97,6 +97,7 @@ KERAS_API_INIT_FILES = [
"keras/applications/inception_v3/__init__.py",
"keras/applications/mobilenet/__init__.py",
"keras/applications/mobilenet_v2/__init__.py",
"keras/applications/mobilenet_v3/__init__.py",
"keras/applications/nasnet/__init__.py",
"keras/applications/resnet/__init__.py",
"keras/applications/resnet_v2/__init__.py",

View File

@ -113,6 +113,7 @@ KERAS_API_INIT_FILES_V1 = [
"keras/applications/inception_v3/__init__.py",
"keras/applications/mobilenet/__init__.py",
"keras/applications/mobilenet_v2/__init__.py",
"keras/applications/mobilenet_v3/__init__.py",
"keras/applications/nasnet/__init__.py",
"keras/applications/resnet/__init__.py",
"keras/applications/resnet_v2/__init__.py",

View File

@ -0,0 +1,11 @@
path: "tensorflow.keras.applications.mobilenet_v3"
tf_module {
member_method {
name: "decode_predictions"
argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], "
}
member_method {
name: "preprocess_input"
argspec: "args=[\'x\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}

View File

@ -28,6 +28,10 @@ tf_module {
name: "mobilenet_v2"
mtype: "<type \'module\'>"
}
member {
name: "mobilenet_v3"
mtype: "<type \'module\'>"
}
member {
name: "nasnet"
mtype: "<type \'module\'>"
@ -116,6 +120,14 @@ tf_module {
name: "MobileNetV2"
argspec: "args=[\'input_shape\', \'alpha\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\', \'classifier_activation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'1.0\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\', \'softmax\'], "
}
member_method {
name: "MobileNetV3Large"
argspec: "args=[\'input_shape\', \'alpha\', \'minimalistic\', \'include_top\', \'weights\', \'input_tensor\', \'classes\', \'pooling\', \'dropout_rate\', \'classifier_activation\'], varargs=None, keywords=None, defaults=[\'None\', \'1.0\', \'False\', \'True\', \'imagenet\', \'None\', \'1000\', \'None\', \'0.2\', \'softmax\'], "
}
member_method {
name: "MobileNetV3Samll"
argspec: "args=[\'input_shape\', \'alpha\', \'minimalistic\', \'include_top\', \'weights\', \'input_tensor\', \'classes\', \'pooling\', \'dropout_rate\', \'classifier_activation\'], varargs=None, keywords=None, defaults=[\'None\', \'1.0\', \'False\', \'True\', \'imagenet\', \'None\', \'1000\', \'None\', \'0.2\', \'softmax\'], "
}
member_method {
name: "NASNetLarge"
argspec: "args=[\'input_shape\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], "

View File

@ -0,0 +1,11 @@
path: "tensorflow.keras.applications.mobilenet_v3"
tf_module {
member_method {
name: "decode_predictions"
argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], "
}
member_method {
name: "preprocess_input"
argspec: "args=[\'x\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}

View File

@ -28,6 +28,10 @@ tf_module {
name: "mobilenet_v2"
mtype: "<type \'module\'>"
}
member {
name: "mobilenet_v3"
mtype: "<type \'module\'>"
}
member {
name: "nasnet"
mtype: "<type \'module\'>"
@ -116,6 +120,14 @@ tf_module {
name: "MobileNetV2"
argspec: "args=[\'input_shape\', \'alpha\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\', \'classifier_activation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'1.0\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\', \'softmax\'], "
}
member_method {
name: "MobileNetV3Large"
argspec: "args=[\'input_shape\', \'alpha\', \'minimalistic\', \'include_top\', \'weights\', \'input_tensor\', \'classes\', \'pooling\', \'dropout_rate\', \'classifier_activation\'], varargs=None, keywords=None, defaults=[\'None\', \'1.0\', \'False\', \'True\', \'imagenet\', \'None\', \'1000\', \'None\', \'0.2\', \'softmax\'], "
}
member_method {
name: "MobileNetV3Samll"
argspec: "args=[\'input_shape\', \'alpha\', \'minimalistic\', \'include_top\', \'weights\', \'input_tensor\', \'classes\', \'pooling\', \'dropout_rate\', \'classifier_activation\'], varargs=None, keywords=None, defaults=[\'None\', \'1.0\', \'False\', \'True\', \'imagenet\', \'None\', \'1000\', \'None\', \'0.2\', \'softmax\'], "
}
member_method {
name: "NASNetLarge"
argspec: "args=[\'input_shape\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], "