Add a classifier_activation option to keras.applications

Defaults to `softmax` (the current behavior) but now users have the option of turning it off, or setting a different activation function.

PiperOrigin-RevId: 296235461
Change-Id: Iea01b8a2b260ece5e91b4dc1e6e42526136a2066
This commit is contained in:
Mark Daoust 2020-02-20 10:01:00 -08:00 committed by TensorFlower Gardener
parent e85f354bba
commit 51c182c4a3
13 changed files with 285 additions and 113 deletions

View File

@ -125,13 +125,16 @@ def conv_block(x, growth_rate, name):
return x
def DenseNet(blocks,
include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000):
def DenseNet(
blocks,
include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
):
"""Instantiates the DenseNet architecture.
Optionally loads weights pre-trained on ImageNet.
@ -169,13 +172,18 @@ def DenseNet(blocks,
classes: optional number of classes to classify images
into, only to be specified if `include_top` is True, and
if no `weights` argument is specified.
classifier_activation: A `str` or callable. The activation function to use
on the "top" layer. Ignored unless `include_top=True`. Set
`classifier_activation=None` to return the logits of the "top" layer.
Returns:
A Keras model instance.
A `keras.Model` instance.
Raises:
ValueError: in case of invalid argument for `weights`,
or invalid input shape.
ValueError: if `classifier_activation` is not `softmax` or `None` when
using a pretrained top layer.
"""
if not (weights in {'imagenet', None} or os.path.exists(weights)):
raise ValueError('The `weights` argument should be either '
@ -228,7 +236,10 @@ def DenseNet(blocks,
if include_top:
x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
x = layers.Dense(classes, activation='softmax', name='fc1000')(x)
imagenet_utils.validate_activation(classifier_activation, weights)
x = layers.Dense(classes, activation=classifier_activation,
name='predictions')(x)
else:
if pooling == 'avg':
x = layers.GlobalAveragePooling2D(name='avg_pool')(x)

View File

@ -141,21 +141,24 @@ DENSE_KERNEL_INITIALIZER = {
}
def EfficientNet(width_coefficient,
depth_coefficient,
default_size,
dropout_rate=0.2,
drop_connect_rate=0.2,
depth_divisor=8,
activation='swish',
blocks_args='default',
model_name='efficientnet',
include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000):
def EfficientNet(
width_coefficient,
depth_coefficient,
default_size,
dropout_rate=0.2,
drop_connect_rate=0.2,
depth_divisor=8,
activation='swish',
blocks_args='default',
model_name='efficientnet',
include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
):
"""Instantiates the EfficientNet architecture using given scaling coefficients.
Optionally loads weights pre-trained on ImageNet.
@ -197,13 +200,18 @@ def EfficientNet(width_coefficient,
classes: optional number of classes to classify images
into, only to be specified if `include_top` is True, and
if no `weights` argument is specified.
classifier_activation: A `str` or callable. The activation function to use
on the "top" layer. Ignored unless `include_top=True`. Set
`classifier_activation=None` to return the logits of the "top" layer.
Returns:
A Keras model instance.
A `keras.Model` instance.
Raises:
ValueError: in case of invalid argument for `weights`,
or invalid input shape.
ValueError: if `classifier_activation` is not `softmax` or `None` when
using a pretrained top layer.
"""
if blocks_args == 'default':
blocks_args = DEFAULT_BLOCKS_ARGS
@ -307,11 +315,12 @@ def EfficientNet(width_coefficient,
x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
if dropout_rate > 0:
x = layers.Dropout(dropout_rate, name='top_dropout')(x)
imagenet_utils.validate_activation(classifier_activation, weights)
x = layers.Dense(
classes,
activation='softmax',
activation=classifier_activation,
kernel_initializer=DENSE_KERNEL_INITIALIZER,
name='probs')(x)
name='predictions')(x)
else:
if pooling == 'avg':
x = layers.GlobalAveragePooling2D(name='avg_pool')(x)

View File

@ -22,6 +22,7 @@ import warnings
import numpy as np
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.util.tf_export import keras_export
@ -355,3 +356,27 @@ def correct_pad(inputs, kernel_size):
correct = (kernel_size[0] // 2, kernel_size[1] // 2)
return ((correct[0] - adjust[0], correct[0]),
(correct[1] - adjust[1], correct[1]))
def validate_activation(classifier_activation, weights):
"""validates that the classifer_activation is compatible with the weights.
Args:
classifier_activation: str or callable activation function
weights: The pretrained weights to load.
Raises:
ValueError: if an activation other than `None` or `softmax` are used with
pretrained weights.
"""
if weights is None:
return
classifier_activation = activations.get(classifier_activation)
if classifier_activation not in [
activations.get('softmax'),
activations.get(None)
]:
raise ValueError('Only `None` and `softmax` activations are allowed '
'for the `classifier_activation` argument when using '
'pretrained weights, with `include_top=True`')

View File

@ -48,6 +48,7 @@ def InceptionResNetV2(include_top=True,
input_shape=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
**kwargs):
"""Instantiates the Inception-ResNet v2 architecture.
@ -82,14 +83,19 @@ def InceptionResNetV2(include_top=True,
classes: optional number of classes to classify images
into, only to be specified if `include_top` is `True`, and
if no `weights` argument is specified.
classifier_activation: A `str` or callable. The activation function to use
on the "top" layer. Ignored unless `include_top=True`. Set
`classifier_activation=None` to return the logits of the "top" layer.
**kwargs: For backwards compatibility only.
Returns:
A Keras `Model` instance.
A `keras.Model` instance.
Raises:
ValueError: in case of invalid argument for `weights`,
or invalid input shape.
ValueError: if `classifier_activation` is not `softmax` or `None` when
using a pretrained top layer.
"""
if 'layers' in kwargs:
global layers
@ -189,7 +195,9 @@ def InceptionResNetV2(include_top=True,
if include_top:
# Classification block
x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
x = layers.Dense(classes, activation='softmax', name='predictions')(x)
imagenet_utils.validate_activation(classifier_activation, weights)
x = layers.Dense(classes, activation=classifier_activation,
name='predictions')(x)
else:
if pooling == 'avg':
x = layers.GlobalAveragePooling2D()(x)

View File

@ -44,12 +44,15 @@ WEIGHTS_PATH_NO_TOP = (
@keras_export('keras.applications.inception_v3.InceptionV3',
'keras.applications.InceptionV3')
def InceptionV3(include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000):
def InceptionV3(
include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
):
"""Instantiates the Inception v3 architecture.
Reference paper:
@ -89,13 +92,18 @@ def InceptionV3(include_top=True,
classes: optional number of classes to classify images
into, only to be specified if `include_top` is True, and
if no `weights` argument is specified. Default to 1000.
classifier_activation: A `str` or callable. The activation function to use
on the "top" layer. Ignored unless `include_top=True`. Set
`classifier_activation=None` to return the logits of the "top" layer.
Returns:
A Keras `tf.keras.Model` instance.
A `keras.Model` instance.
Raises:
ValueError: in case of invalid argument for `weights`,
or invalid input shape.
ValueError: if `classifier_activation` is not `softmax` or `None` when
using a pretrained top layer.
"""
if not (weights in {'imagenet', None} or os.path.exists(weights)):
raise ValueError('The `weights` argument should be either '
@ -309,7 +317,9 @@ def InceptionV3(include_top=True,
if include_top:
# Classification block
x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
x = layers.Dense(classes, activation='softmax', name='predictions')(x)
imagenet_utils.validate_activation(classifier_activation, weights)
x = layers.Dense(classes, activation=classifier_activation,
name='predictions')(x)
else:
if pooling == 'avg':
x = layers.GlobalAveragePooling2D()(x)

View File

@ -90,6 +90,7 @@ def MobileNet(input_shape=None,
input_tensor=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
**kwargs):
"""Instantiates the MobileNet architecture.
@ -138,14 +139,18 @@ def MobileNet(input_shape=None,
classes: Optional number of classes to classify images into, only to be
specified if `include_top` is True, and if no `weights` argument is
specified. Defaults to 1000.
classifier_activation: A `str` or callable. The activation function to use
on the "top" layer. Ignored unless `include_top=True`. Set
`classifier_activation=None` to return the logits of the "top" layer.
**kwargs: For backwards compatibility only.
Returns:
A `tf.keras.Model` instance.
A `keras.Model` instance.
Raises:
ValueError: in case of invalid argument for `weights`,
or invalid input shape.
ValueError: if `classifier_activation` is not `softmax` or `None` when
using a pretrained top layer.
"""
if 'layers' in kwargs:
global layers
@ -252,7 +257,9 @@ def MobileNet(input_shape=None,
x = layers.Dropout(dropout, name='dropout')(x)
x = layers.Conv2D(classes, (1, 1), padding='same', name='conv_preds')(x)
x = layers.Reshape((classes,), name='reshape_2')(x)
x = layers.Activation('softmax', name='act_softmax')(x)
imagenet_utils.validate_activation(classifier_activation, weights)
x = layers.Activation(activation=classifier_activation,
name='predictions')(x)
else:
if pooling == 'avg':
x = layers.GlobalAveragePooling2D()(x)

View File

@ -85,7 +85,6 @@ from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export
BASE_WEIGHT_PATH = ('https://storage.googleapis.com/tensorflow/'
'keras-applications/mobilenet_v2/')
@ -99,6 +98,7 @@ def MobileNetV2(input_shape=None,
input_tensor=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
**kwargs):
"""Instantiates the MobileNetV2 architecture.
@ -152,6 +152,9 @@ def MobileNetV2(input_shape=None,
classes: Integer, optional number of classes to classify images
into, only to be specified if `include_top` is True, and
if no `weights` argument is specified.
classifier_activation: A `str` or callable. The activation function to use
on the "top" layer. Ignored unless `include_top=True`. Set
`classifier_activation=None` to return the logits of the "top" layer.
**kwargs: For backwards compatibility only.
Returns:
@ -161,6 +164,8 @@ def MobileNetV2(input_shape=None,
ValueError: in case of invalid argument for `weights`,
or invalid input shape or invalid alpha, rows when
weights='imagenet'
ValueError: if `classifier_activation` is not `softmax` or `None` when
using a pretrained top layer.
"""
if 'layers' in kwargs:
global layers
@ -360,9 +365,10 @@ def MobileNetV2(input_shape=None,
if include_top:
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(
classes, activation='softmax', use_bias=True, name='Logits')(
x)
imagenet_utils.validate_activation(classifier_activation, weights)
x = layers.Dense(classes, activation=classifier_activation,
name='predictions')(x)
else:
if pooling == 'avg':
x = layers.GlobalAveragePooling2D()(x)

View File

@ -61,18 +61,21 @@ NASNET_LARGE_WEIGHT_PATH = BASE_WEIGHTS_PATH + 'NASNet-large.h5'
NASNET_LARGE_WEIGHT_PATH_NO_TOP = BASE_WEIGHTS_PATH + 'NASNet-large-no-top.h5'
def NASNet(input_shape=None,
penultimate_filters=4032,
num_blocks=6,
stem_block_filters=96,
skip_reduction=True,
filter_multiplier=2,
include_top=True,
weights=None,
input_tensor=None,
pooling=None,
classes=1000,
default_size=None):
def NASNet(
input_shape=None,
penultimate_filters=4032,
num_blocks=6,
stem_block_filters=96,
skip_reduction=True,
filter_multiplier=2,
include_top=True,
weights=None,
input_tensor=None,
pooling=None,
classes=1000,
default_size=None,
classifier_activation='softmax',
):
"""Instantiates a NASNet model.
Optionally loads weights pre-trained on ImageNet.
@ -127,13 +130,18 @@ def NASNet(input_shape=None,
into, only to be specified if `include_top` is True, and
if no `weights` argument is specified.
default_size: Specifies the default image size of the model
classifier_activation: A `str` or callable. The activation function to use
on the "top" layer. Ignored unless `include_top=True`. Set
`classifier_activation=None` to return the logits of the "top" layer.
Returns:
A Keras model instance.
A `keras.Model` instance.
Raises:
ValueError: In case of invalid argument for `weights`,
invalid input shape or invalid `penultimate_filters` value.
invalid input shape or invalid `penultimate_filters` value.
ValueError: if `classifier_activation` is not `softmax` or `None` when
using a pretrained top layer.
"""
if not (weights in {'imagenet', None} or os.path.exists(weights)):
raise ValueError('The `weights` argument should be either '
@ -247,7 +255,9 @@ def NASNet(input_shape=None,
if include_top:
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(classes, activation='softmax', name='predictions')(x)
imagenet_utils.validate_activation(classifier_activation, weights)
x = layers.Dense(classes, activation=classifier_activation,
name='predictions')(x)
else:
if pooling == 'avg':
x = layers.GlobalAveragePooling2D()(x)

View File

@ -61,6 +61,7 @@ def ResNet(stack_fn,
input_shape=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
**kwargs):
"""Instantiates the ResNet, ResNetV2, and ResNeXt architecture.
@ -103,14 +104,18 @@ def ResNet(stack_fn,
classes: optional number of classes to classify images
into, only to be specified if `include_top` is True, and
if no `weights` argument is specified.
classifier_activation: A `str` or callable. The activation function to use
on the "top" layer. Ignored unless `include_top=True`. Set
`classifier_activation=None` to return the logits of the "top" layer.
**kwargs: For backwards compatibility only.
Returns:
A Keras model instance.
A `keras.Model` instance.
Raises:
ValueError: in case of invalid argument for `weights`,
or invalid input shape.
ValueError: if `classifier_activation` is not `softmax` or `None` when
using a pretrained top layer.
"""
if 'layers' in kwargs:
global layers
@ -167,7 +172,9 @@ def ResNet(stack_fn,
if include_top:
x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
x = layers.Dense(classes, activation='softmax', name='probs')(x)
imagenet_utils.validate_activation(classifier_activation, weights)
x = layers.Dense(classes, activation=classifier_activation,
name='predictions')(x)
else:
if pooling == 'avg':
x = layers.GlobalAveragePooling2D(name='avg_pool')(x)

View File

@ -25,56 +25,101 @@ from tensorflow.python.util.tf_export import keras_export
@keras_export('keras.applications.resnet_v2.ResNet50V2',
'keras.applications.ResNet50V2')
def ResNet50V2(include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000):
def ResNet50V2(
include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
):
"""Instantiates the ResNet50V2 architecture."""
def stack_fn(x):
x = resnet.stack2(x, 64, 3, name='conv2')
x = resnet.stack2(x, 128, 4, name='conv3')
x = resnet.stack2(x, 256, 6, name='conv4')
return resnet.stack2(x, 512, 3, stride1=1, name='conv5')
return resnet.ResNet(stack_fn, True, True, 'resnet50v2', include_top, weights,
input_tensor, input_shape, pooling, classes)
return resnet.ResNet(
stack_fn,
True,
True,
'resnet50v2',
include_top,
weights,
input_tensor,
input_shape,
pooling,
classes,
classifier_activation=classifier_activation,
)
@keras_export('keras.applications.resnet_v2.ResNet101V2',
'keras.applications.ResNet101V2')
def ResNet101V2(include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000):
def ResNet101V2(
include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
):
"""Instantiates the ResNet101V2 architecture."""
def stack_fn(x):
x = resnet.stack2(x, 64, 3, name='conv2')
x = resnet.stack2(x, 128, 4, name='conv3')
x = resnet.stack2(x, 256, 23, name='conv4')
return resnet.stack2(x, 512, 3, stride1=1, name='conv5')
return resnet.ResNet(stack_fn, True, True, 'resnet101v2', include_top,
weights, input_tensor, input_shape, pooling, classes)
return resnet.ResNet(
stack_fn,
True,
True,
'resnet101v2',
include_top,
weights,
input_tensor,
input_shape,
pooling,
classes,
classifier_activation=classifier_activation,
)
@keras_export('keras.applications.resnet_v2.ResNet152V2',
'keras.applications.ResNet152V2')
def ResNet152V2(include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000):
def ResNet152V2(
include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
):
"""Instantiates the ResNet152V2 architecture."""
def stack_fn(x):
x = resnet.stack2(x, 64, 3, name='conv2')
x = resnet.stack2(x, 128, 8, name='conv3')
x = resnet.stack2(x, 256, 36, name='conv4')
return resnet.stack2(x, 512, 3, stride1=1, name='conv5')
return resnet.ResNet(stack_fn, True, True, 'resnet152v2', include_top,
weights, input_tensor, input_shape, pooling, classes)
return resnet.ResNet(
stack_fn,
True,
True,
'resnet152v2',
include_top,
weights,
input_tensor,
input_shape,
pooling,
classes,
classifier_activation=classifier_activation,
)
@keras_export('keras.applications.resnet_v2.preprocess_input')
@ -123,9 +168,12 @@ DOC = """
classes: optional number of classes to classify images
into, only to be specified if `include_top` is True, and
if no `weights` argument is specified.
classifier_activation: A `str` or callable. The activation function to use
on the "top" layer. Ignored unless `include_top=True`. Set
`classifier_activation=None` to return the logits of the "top" layer.
Returns:
A Keras model instance.
A `keras.Model` instance.
"""
setattr(ResNet50V2, '__doc__', ResNet50V2.__doc__ + DOC)

View File

@ -37,12 +37,15 @@ WEIGHTS_PATH_NO_TOP = ('https://storage.googleapis.com/tensorflow/'
@keras_export('keras.applications.vgg16.VGG16', 'keras.applications.VGG16')
def VGG16(include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000):
def VGG16(
include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
):
"""Instantiates the VGG16 model.
By default, it loads weights pre-trained on ImageNet. Check 'weights' for
@ -85,13 +88,18 @@ def VGG16(include_top=True,
classes: optional number of classes to classify images
into, only to be specified if `include_top` is True, and
if no `weights` argument is specified.
classifier_activation: A `str` or callable. The activation function to use
on the "top" layer. Ignored unless `include_top=True`. Set
`classifier_activation=None` to return the logits of the "top" layer.
Returns:
A Keras model instance.
A `keras.Model` instance.
Raises:
ValueError: in case of invalid argument for `weights`,
or invalid input shape.
ValueError: in case of invalid argument for `weights`,
or invalid input shape.
ValueError: if `classifier_activation` is not `softmax` or `None` when
using a pretrained top layer.
"""
if not (weights in {'imagenet', None} or os.path.exists(weights)):
raise ValueError('The `weights` argument should be either '
@ -165,7 +173,10 @@ def VGG16(include_top=True,
x = layers.Flatten(name='flatten')(x)
x = layers.Dense(4096, activation='relu', name='fc1')(x)
x = layers.Dense(4096, activation='relu', name='fc2')(x)
x = layers.Dense(classes, activation='softmax', name='predictions')(x)
imagenet_utils.validate_activation(classifier_activation, weights)
x = layers.Dense(classes, activation=classifier_activation,
name='predictions')(x)
else:
if pooling == 'avg':
x = layers.GlobalAveragePooling2D()(x)

View File

@ -42,12 +42,15 @@ WEIGHTS_PATH_NO_TOP = ('https://storage.googleapis.com/tensorflow/'
@keras_export('keras.applications.vgg19.VGG19', 'keras.applications.VGG19')
def VGG19(include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000):
def VGG19(
include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
):
"""Instantiates the VGG19 architecture.
By default, it loads weights pre-trained on ImageNet. Check 'weights' for
@ -90,13 +93,18 @@ def VGG19(include_top=True,
classes: optional number of classes to classify images
into, only to be specified if `include_top` is True, and
if no `weights` argument is specified.
classifier_activation: A `str` or callable. The activation function to use
on the "top" layer. Ignored unless `include_top=True`. Set
`classifier_activation=None` to return the logits of the "top" layer.
Returns:
A Keras model instance.
A `keras.Model` instance.
Raises:
ValueError: in case of invalid argument for `weights`,
or invalid input shape.
ValueError: if `classifier_activation` is not `softmax` or `None` when
using a pretrained top layer.
"""
if not (weights in {'imagenet', None} or os.path.exists(weights)):
raise ValueError('The `weights` argument should be either '
@ -176,7 +184,9 @@ def VGG19(include_top=True,
x = layers.Flatten(name='flatten')(x)
x = layers.Dense(4096, activation='relu', name='fc1')(x)
x = layers.Dense(4096, activation='relu', name='fc2')(x)
x = layers.Dense(classes, activation='softmax', name='predictions')(x)
imagenet_utils.validate_activation(classifier_activation, weights)
x = layers.Dense(classes, activation=classifier_activation,
name='predictions')(x)
else:
if pooling == 'avg':
x = layers.GlobalAveragePooling2D()(x)

View File

@ -48,12 +48,15 @@ TF_WEIGHTS_PATH_NO_TOP = (
@keras_export('keras.applications.xception.Xception',
'keras.applications.Xception')
def Xception(include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000):
def Xception(
include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
):
"""Instantiates the Xception architecture.
Optionally loads weights pre-trained on ImageNet.
@ -90,13 +93,18 @@ def Xception(include_top=True,
classes: optional number of classes to classify images
into, only to be specified if `include_top` is True,
and if no `weights` argument is specified.
classifier_activation: A `str` or callable. The activation function to use
on the "top" layer. Ignored unless `include_top=True`. Set
`classifier_activation=None` to return the logits of the "top" layer.
Returns:
A Keras model instance.
A `keras.Model` instance.
Raises:
ValueError: in case of invalid argument for `weights`,
or invalid input shape.
ValueError: if `classifier_activation` is not `softmax` or `None` when
using a pretrained top layer.
"""
if not (weights in {'imagenet', None} or os.path.exists(weights)):
raise ValueError('The `weights` argument should be either '
@ -260,7 +268,9 @@ def Xception(include_top=True,
if include_top:
x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
x = layers.Dense(classes, activation='softmax', name='predictions')(x)
imagenet_utils.validate_activation(classifier_activation, weights)
x = layers.Dense(classes, activation=classifier_activation,
name='predictions')(x)
else:
if pooling == 'avg':
x = layers.GlobalAveragePooling2D()(x)