From 51c182c4a35e55b969f9934f9ba85d840cfb4b92 Mon Sep 17 00:00:00 2001
From: Mark Daoust <markdaoust@google.com>
Date: Thu, 20 Feb 2020 10:01:00 -0800
Subject: [PATCH] Add a `classifier_activation` option to keras.applications

Defaults to `softmax` (the current behavior) but now users have the option of turning it off, or setting a different activation function.

PiperOrigin-RevId: 296235461
Change-Id: Iea01b8a2b260ece5e91b4dc1e6e42526136a2066
---
 .../python/keras/applications/densenet.py     | 29 ++++--
 .../python/keras/applications/efficientnet.py | 45 +++++----
 .../keras/applications/imagenet_utils.py      | 25 +++++
 .../keras/applications/inception_resnet_v2.py | 12 ++-
 .../python/keras/applications/inception_v3.py | 26 +++--
 .../python/keras/applications/mobilenet.py    | 13 ++-
 .../python/keras/applications/mobilenet_v2.py | 14 ++-
 .../python/keras/applications/nasnet.py       | 40 +++++---
 .../python/keras/applications/resnet.py       | 13 ++-
 .../python/keras/applications/resnet_v2.py    | 98 ++++++++++++++-----
 tensorflow/python/keras/applications/vgg16.py | 31 ++++--
 tensorflow/python/keras/applications/vgg19.py | 26 +++--
 .../python/keras/applications/xception.py     | 26 +++--
 13 files changed, 285 insertions(+), 113 deletions(-)

diff --git a/tensorflow/python/keras/applications/densenet.py b/tensorflow/python/keras/applications/densenet.py
index 237202ff429..9a7be9a3b7a 100644
--- a/tensorflow/python/keras/applications/densenet.py
+++ b/tensorflow/python/keras/applications/densenet.py
@@ -125,13 +125,16 @@ def conv_block(x, growth_rate, name):
   return x
 
 
-def DenseNet(blocks,
-             include_top=True,
-             weights='imagenet',
-             input_tensor=None,
-             input_shape=None,
-             pooling=None,
-             classes=1000):
+def DenseNet(
+    blocks,
+    include_top=True,
+    weights='imagenet',
+    input_tensor=None,
+    input_shape=None,
+    pooling=None,
+    classes=1000,
+    classifier_activation='softmax',
+):
   """Instantiates the DenseNet architecture.
 
   Optionally loads weights pre-trained on ImageNet.
@@ -169,13 +172,18 @@ def DenseNet(blocks,
     classes: optional number of classes to classify images
       into, only to be specified if `include_top` is True, and
       if no `weights` argument is specified.
+    classifier_activation: A `str` or callable. The activation function to use
+      on the "top" layer. Ignored unless `include_top=True`. Set
+      `classifier_activation=None` to return the logits of the "top" layer.
 
   Returns:
-    A Keras model instance.
+    A `keras.Model` instance.
 
   Raises:
     ValueError: in case of invalid argument for `weights`,
       or invalid input shape.
+    ValueError: if `classifier_activation` is not `softmax` or `None` when
+      using a pretrained top layer.
   """
   if not (weights in {'imagenet', None} or os.path.exists(weights)):
     raise ValueError('The `weights` argument should be either '
@@ -228,7 +236,10 @@ def DenseNet(blocks,
 
   if include_top:
     x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
-    x = layers.Dense(classes, activation='softmax', name='fc1000')(x)
+
+    imagenet_utils.validate_activation(classifier_activation, weights)
+    x = layers.Dense(classes, activation=classifier_activation,
+                     name='predictions')(x)
   else:
     if pooling == 'avg':
       x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
diff --git a/tensorflow/python/keras/applications/efficientnet.py b/tensorflow/python/keras/applications/efficientnet.py
index f3d0f1e5b0e..11ba3a98b7e 100644
--- a/tensorflow/python/keras/applications/efficientnet.py
+++ b/tensorflow/python/keras/applications/efficientnet.py
@@ -141,21 +141,24 @@ DENSE_KERNEL_INITIALIZER = {
 }
 
 
-def EfficientNet(width_coefficient,
-                 depth_coefficient,
-                 default_size,
-                 dropout_rate=0.2,
-                 drop_connect_rate=0.2,
-                 depth_divisor=8,
-                 activation='swish',
-                 blocks_args='default',
-                 model_name='efficientnet',
-                 include_top=True,
-                 weights='imagenet',
-                 input_tensor=None,
-                 input_shape=None,
-                 pooling=None,
-                 classes=1000):
+def EfficientNet(
+    width_coefficient,
+    depth_coefficient,
+    default_size,
+    dropout_rate=0.2,
+    drop_connect_rate=0.2,
+    depth_divisor=8,
+    activation='swish',
+    blocks_args='default',
+    model_name='efficientnet',
+    include_top=True,
+    weights='imagenet',
+    input_tensor=None,
+    input_shape=None,
+    pooling=None,
+    classes=1000,
+    classifier_activation='softmax',
+):
   """Instantiates the EfficientNet architecture using given scaling coefficients.
 
   Optionally loads weights pre-trained on ImageNet.
@@ -197,13 +200,18 @@ def EfficientNet(width_coefficient,
     classes: optional number of classes to classify images
         into, only to be specified if `include_top` is True, and
         if no `weights` argument is specified.
+    classifier_activation: A `str` or callable. The activation function to use
+        on the "top" layer. Ignored unless `include_top=True`. Set
+        `classifier_activation=None` to return the logits of the "top" layer.
 
   Returns:
-    A Keras model instance.
+    A `keras.Model` instance.
 
   Raises:
     ValueError: in case of invalid argument for `weights`,
       or invalid input shape.
+    ValueError: if `classifier_activation` is not `softmax` or `None` when
+      using a pretrained top layer.
   """
   if blocks_args == 'default':
     blocks_args = DEFAULT_BLOCKS_ARGS
@@ -307,11 +315,12 @@ def EfficientNet(width_coefficient,
     x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
     if dropout_rate > 0:
       x = layers.Dropout(dropout_rate, name='top_dropout')(x)
+    imagenet_utils.validate_activation(classifier_activation, weights)
     x = layers.Dense(
         classes,
-        activation='softmax',
+        activation=classifier_activation,
         kernel_initializer=DENSE_KERNEL_INITIALIZER,
-        name='probs')(x)
+        name='predictions')(x)
   else:
     if pooling == 'avg':
       x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
diff --git a/tensorflow/python/keras/applications/imagenet_utils.py b/tensorflow/python/keras/applications/imagenet_utils.py
index 206be8406ee..55299ebfa50 100644
--- a/tensorflow/python/keras/applications/imagenet_utils.py
+++ b/tensorflow/python/keras/applications/imagenet_utils.py
@@ -22,6 +22,7 @@ import warnings
 
 import numpy as np
 
+from tensorflow.python.keras import activations
 from tensorflow.python.keras import backend
 from tensorflow.python.keras.utils import data_utils
 from tensorflow.python.util.tf_export import keras_export
@@ -355,3 +356,27 @@ def correct_pad(inputs, kernel_size):
   correct = (kernel_size[0] // 2, kernel_size[1] // 2)
   return ((correct[0] - adjust[0], correct[0]),
           (correct[1] - adjust[1], correct[1]))
+
+
+def validate_activation(classifier_activation, weights):
+  """validates that the classifer_activation is compatible with the weights.
+
+  Args:
+    classifier_activation: str or callable activation function
+    weights: The pretrained weights to load.
+
+  Raises:
+    ValueError: if an activation other than `None` or `softmax` are used with
+      pretrained weights.
+  """
+  if weights is None:
+    return
+
+  classifier_activation = activations.get(classifier_activation)
+  if classifier_activation not in [
+      activations.get('softmax'),
+      activations.get(None)
+  ]:
+    raise ValueError('Only `None` and `softmax` activations are allowed '
+                     'for the `classifier_activation` argument when using '
+                     'pretrained weights, with `include_top=True`')
diff --git a/tensorflow/python/keras/applications/inception_resnet_v2.py b/tensorflow/python/keras/applications/inception_resnet_v2.py
index 092343144c7..ab8ab71e3b0 100644
--- a/tensorflow/python/keras/applications/inception_resnet_v2.py
+++ b/tensorflow/python/keras/applications/inception_resnet_v2.py
@@ -48,6 +48,7 @@ def InceptionResNetV2(include_top=True,
                       input_shape=None,
                       pooling=None,
                       classes=1000,
+                      classifier_activation='softmax',
                       **kwargs):
   """Instantiates the Inception-ResNet v2 architecture.
 
@@ -82,14 +83,19 @@ def InceptionResNetV2(include_top=True,
     classes: optional number of classes to classify images
       into, only to be specified if `include_top` is `True`, and
       if no `weights` argument is specified.
+    classifier_activation: A `str` or callable. The activation function to use
+      on the "top" layer. Ignored unless `include_top=True`. Set
+      `classifier_activation=None` to return the logits of the "top" layer.
     **kwargs: For backwards compatibility only.
 
   Returns:
-    A Keras `Model` instance.
+    A `keras.Model` instance.
 
   Raises:
     ValueError: in case of invalid argument for `weights`,
       or invalid input shape.
+    ValueError: if `classifier_activation` is not `softmax` or `None` when
+      using a pretrained top layer.
   """
   if 'layers' in kwargs:
     global layers
@@ -189,7 +195,9 @@ def InceptionResNetV2(include_top=True,
   if include_top:
     # Classification block
     x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
-    x = layers.Dense(classes, activation='softmax', name='predictions')(x)
+    imagenet_utils.validate_activation(classifier_activation, weights)
+    x = layers.Dense(classes, activation=classifier_activation,
+                     name='predictions')(x)
   else:
     if pooling == 'avg':
       x = layers.GlobalAveragePooling2D()(x)
diff --git a/tensorflow/python/keras/applications/inception_v3.py b/tensorflow/python/keras/applications/inception_v3.py
index ecec195dff6..f8a56e62234 100644
--- a/tensorflow/python/keras/applications/inception_v3.py
+++ b/tensorflow/python/keras/applications/inception_v3.py
@@ -44,12 +44,15 @@ WEIGHTS_PATH_NO_TOP = (
 
 @keras_export('keras.applications.inception_v3.InceptionV3',
               'keras.applications.InceptionV3')
-def InceptionV3(include_top=True,
-                weights='imagenet',
-                input_tensor=None,
-                input_shape=None,
-                pooling=None,
-                classes=1000):
+def InceptionV3(
+    include_top=True,
+    weights='imagenet',
+    input_tensor=None,
+    input_shape=None,
+    pooling=None,
+    classes=1000,
+    classifier_activation='softmax',
+):
   """Instantiates the Inception v3 architecture.
 
   Reference paper:
@@ -89,13 +92,18 @@ def InceptionV3(include_top=True,
     classes: optional number of classes to classify images
       into, only to be specified if `include_top` is True, and
       if no `weights` argument is specified. Default to 1000.
+    classifier_activation: A `str` or callable. The activation function to use
+      on the "top" layer. Ignored unless `include_top=True`. Set
+      `classifier_activation=None` to return the logits of the "top" layer.
 
   Returns:
-    A Keras `tf.keras.Model` instance.
+    A `keras.Model` instance.
 
   Raises:
     ValueError: in case of invalid argument for `weights`,
       or invalid input shape.
+    ValueError: if `classifier_activation` is not `softmax` or `None` when
+      using a pretrained top layer.
   """
   if not (weights in {'imagenet', None} or os.path.exists(weights)):
     raise ValueError('The `weights` argument should be either '
@@ -309,7 +317,9 @@ def InceptionV3(include_top=True,
   if include_top:
     # Classification block
     x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
-    x = layers.Dense(classes, activation='softmax', name='predictions')(x)
+    imagenet_utils.validate_activation(classifier_activation, weights)
+    x = layers.Dense(classes, activation=classifier_activation,
+                     name='predictions')(x)
   else:
     if pooling == 'avg':
       x = layers.GlobalAveragePooling2D()(x)
diff --git a/tensorflow/python/keras/applications/mobilenet.py b/tensorflow/python/keras/applications/mobilenet.py
index e64efa53815..224e8c84496 100644
--- a/tensorflow/python/keras/applications/mobilenet.py
+++ b/tensorflow/python/keras/applications/mobilenet.py
@@ -90,6 +90,7 @@ def MobileNet(input_shape=None,
               input_tensor=None,
               pooling=None,
               classes=1000,
+              classifier_activation='softmax',
               **kwargs):
   """Instantiates the MobileNet architecture.
 
@@ -138,14 +139,18 @@ def MobileNet(input_shape=None,
     classes: Optional number of classes to classify images into, only to be
       specified if `include_top` is True, and if no `weights` argument is
       specified. Defaults to 1000.
+    classifier_activation: A `str` or callable. The activation function to use
+      on the "top" layer. Ignored unless `include_top=True`. Set
+      `classifier_activation=None` to return the logits of the "top" layer.
     **kwargs: For backwards compatibility only.
-
   Returns:
-    A `tf.keras.Model` instance.
+    A `keras.Model` instance.
 
   Raises:
     ValueError: in case of invalid argument for `weights`,
       or invalid input shape.
+    ValueError: if `classifier_activation` is not `softmax` or `None` when
+      using a pretrained top layer.
   """
   if 'layers' in kwargs:
     global layers
@@ -252,7 +257,9 @@ def MobileNet(input_shape=None,
     x = layers.Dropout(dropout, name='dropout')(x)
     x = layers.Conv2D(classes, (1, 1), padding='same', name='conv_preds')(x)
     x = layers.Reshape((classes,), name='reshape_2')(x)
-    x = layers.Activation('softmax', name='act_softmax')(x)
+    imagenet_utils.validate_activation(classifier_activation, weights)
+    x = layers.Activation(activation=classifier_activation,
+                          name='predictions')(x)
   else:
     if pooling == 'avg':
       x = layers.GlobalAveragePooling2D()(x)
diff --git a/tensorflow/python/keras/applications/mobilenet_v2.py b/tensorflow/python/keras/applications/mobilenet_v2.py
index 186b6e3db61..a983f6d7e46 100644
--- a/tensorflow/python/keras/applications/mobilenet_v2.py
+++ b/tensorflow/python/keras/applications/mobilenet_v2.py
@@ -85,7 +85,6 @@ from tensorflow.python.keras.utils import layer_utils
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util.tf_export import keras_export
 
-
 BASE_WEIGHT_PATH = ('https://storage.googleapis.com/tensorflow/'
                     'keras-applications/mobilenet_v2/')
 
@@ -99,6 +98,7 @@ def MobileNetV2(input_shape=None,
                 input_tensor=None,
                 pooling=None,
                 classes=1000,
+                classifier_activation='softmax',
                 **kwargs):
   """Instantiates the MobileNetV2 architecture.
 
@@ -152,6 +152,9 @@ def MobileNetV2(input_shape=None,
     classes: Integer, optional number of classes to classify images
       into, only to be specified if `include_top` is True, and
       if no `weights` argument is specified.
+    classifier_activation: A `str` or callable. The activation function to use
+      on the "top" layer. Ignored unless `include_top=True`. Set
+      `classifier_activation=None` to return the logits of the "top" layer.
     **kwargs: For backwards compatibility only.
 
   Returns:
@@ -161,6 +164,8 @@ def MobileNetV2(input_shape=None,
     ValueError: in case of invalid argument for `weights`,
       or invalid input shape or invalid alpha, rows when
       weights='imagenet'
+    ValueError: if `classifier_activation` is not `softmax` or `None` when
+      using a pretrained top layer.
   """
   if 'layers' in kwargs:
     global layers
@@ -360,9 +365,10 @@ def MobileNetV2(input_shape=None,
 
   if include_top:
     x = layers.GlobalAveragePooling2D()(x)
-    x = layers.Dense(
-        classes, activation='softmax', use_bias=True, name='Logits')(
-            x)
+    imagenet_utils.validate_activation(classifier_activation, weights)
+    x = layers.Dense(classes, activation=classifier_activation,
+                     name='predictions')(x)
+
   else:
     if pooling == 'avg':
       x = layers.GlobalAveragePooling2D()(x)
diff --git a/tensorflow/python/keras/applications/nasnet.py b/tensorflow/python/keras/applications/nasnet.py
index 0a693b83652..a29d5f4c380 100644
--- a/tensorflow/python/keras/applications/nasnet.py
+++ b/tensorflow/python/keras/applications/nasnet.py
@@ -61,18 +61,21 @@ NASNET_LARGE_WEIGHT_PATH = BASE_WEIGHTS_PATH + 'NASNet-large.h5'
 NASNET_LARGE_WEIGHT_PATH_NO_TOP = BASE_WEIGHTS_PATH + 'NASNet-large-no-top.h5'
 
 
-def NASNet(input_shape=None,
-           penultimate_filters=4032,
-           num_blocks=6,
-           stem_block_filters=96,
-           skip_reduction=True,
-           filter_multiplier=2,
-           include_top=True,
-           weights=None,
-           input_tensor=None,
-           pooling=None,
-           classes=1000,
-           default_size=None):
+def NASNet(
+    input_shape=None,
+    penultimate_filters=4032,
+    num_blocks=6,
+    stem_block_filters=96,
+    skip_reduction=True,
+    filter_multiplier=2,
+    include_top=True,
+    weights=None,
+    input_tensor=None,
+    pooling=None,
+    classes=1000,
+    default_size=None,
+    classifier_activation='softmax',
+):
   """Instantiates a NASNet model.
 
   Optionally loads weights pre-trained on ImageNet.
@@ -127,13 +130,18 @@ def NASNet(input_shape=None,
       into, only to be specified if `include_top` is True, and
       if no `weights` argument is specified.
     default_size: Specifies the default image size of the model
+    classifier_activation: A `str` or callable. The activation function to use
+      on the "top" layer. Ignored unless `include_top=True`. Set
+      `classifier_activation=None` to return the logits of the "top" layer.
 
   Returns:
-    A Keras model instance.
+    A `keras.Model` instance.
 
   Raises:
     ValueError: In case of invalid argument for `weights`,
-        invalid input shape or invalid `penultimate_filters` value.
+      invalid input shape or invalid `penultimate_filters` value.
+    ValueError: if `classifier_activation` is not `softmax` or `None` when
+      using a pretrained top layer.
   """
   if not (weights in {'imagenet', None} or os.path.exists(weights)):
     raise ValueError('The `weights` argument should be either '
@@ -247,7 +255,9 @@ def NASNet(input_shape=None,
 
   if include_top:
     x = layers.GlobalAveragePooling2D()(x)
-    x = layers.Dense(classes, activation='softmax', name='predictions')(x)
+    imagenet_utils.validate_activation(classifier_activation, weights)
+    x = layers.Dense(classes, activation=classifier_activation,
+                     name='predictions')(x)
   else:
     if pooling == 'avg':
       x = layers.GlobalAveragePooling2D()(x)
diff --git a/tensorflow/python/keras/applications/resnet.py b/tensorflow/python/keras/applications/resnet.py
index d30b3cca55e..86d26695373 100644
--- a/tensorflow/python/keras/applications/resnet.py
+++ b/tensorflow/python/keras/applications/resnet.py
@@ -61,6 +61,7 @@ def ResNet(stack_fn,
            input_shape=None,
            pooling=None,
            classes=1000,
+           classifier_activation='softmax',
            **kwargs):
   """Instantiates the ResNet, ResNetV2, and ResNeXt architecture.
 
@@ -103,14 +104,18 @@ def ResNet(stack_fn,
     classes: optional number of classes to classify images
       into, only to be specified if `include_top` is True, and
       if no `weights` argument is specified.
+    classifier_activation: A `str` or callable. The activation function to use
+      on the "top" layer. Ignored unless `include_top=True`. Set
+      `classifier_activation=None` to return the logits of the "top" layer.
     **kwargs: For backwards compatibility only.
-
   Returns:
-    A Keras model instance.
+    A `keras.Model` instance.
 
   Raises:
     ValueError: in case of invalid argument for `weights`,
       or invalid input shape.
+    ValueError: if `classifier_activation` is not `softmax` or `None` when
+      using a pretrained top layer.
   """
   if 'layers' in kwargs:
     global layers
@@ -167,7 +172,9 @@ def ResNet(stack_fn,
 
   if include_top:
     x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
-    x = layers.Dense(classes, activation='softmax', name='probs')(x)
+    imagenet_utils.validate_activation(classifier_activation, weights)
+    x = layers.Dense(classes, activation=classifier_activation,
+                     name='predictions')(x)
   else:
     if pooling == 'avg':
       x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
diff --git a/tensorflow/python/keras/applications/resnet_v2.py b/tensorflow/python/keras/applications/resnet_v2.py
index ce56fbb19cb..2e31017dfa9 100644
--- a/tensorflow/python/keras/applications/resnet_v2.py
+++ b/tensorflow/python/keras/applications/resnet_v2.py
@@ -25,56 +25,101 @@ from tensorflow.python.util.tf_export import keras_export
 
 @keras_export('keras.applications.resnet_v2.ResNet50V2',
               'keras.applications.ResNet50V2')
-def ResNet50V2(include_top=True,
-               weights='imagenet',
-               input_tensor=None,
-               input_shape=None,
-               pooling=None,
-               classes=1000):
+def ResNet50V2(
+    include_top=True,
+    weights='imagenet',
+    input_tensor=None,
+    input_shape=None,
+    pooling=None,
+    classes=1000,
+    classifier_activation='softmax',
+):
   """Instantiates the ResNet50V2 architecture."""
   def stack_fn(x):
     x = resnet.stack2(x, 64, 3, name='conv2')
     x = resnet.stack2(x, 128, 4, name='conv3')
     x = resnet.stack2(x, 256, 6, name='conv4')
     return resnet.stack2(x, 512, 3, stride1=1, name='conv5')
-  return resnet.ResNet(stack_fn, True, True, 'resnet50v2', include_top, weights,
-                       input_tensor, input_shape, pooling, classes)
+
+  return resnet.ResNet(
+      stack_fn,
+      True,
+      True,
+      'resnet50v2',
+      include_top,
+      weights,
+      input_tensor,
+      input_shape,
+      pooling,
+      classes,
+      classifier_activation=classifier_activation,
+  )
 
 
 @keras_export('keras.applications.resnet_v2.ResNet101V2',
               'keras.applications.ResNet101V2')
-def ResNet101V2(include_top=True,
-                weights='imagenet',
-                input_tensor=None,
-                input_shape=None,
-                pooling=None,
-                classes=1000):
+def ResNet101V2(
+    include_top=True,
+    weights='imagenet',
+    input_tensor=None,
+    input_shape=None,
+    pooling=None,
+    classes=1000,
+    classifier_activation='softmax',
+):
   """Instantiates the ResNet101V2 architecture."""
   def stack_fn(x):
     x = resnet.stack2(x, 64, 3, name='conv2')
     x = resnet.stack2(x, 128, 4, name='conv3')
     x = resnet.stack2(x, 256, 23, name='conv4')
     return resnet.stack2(x, 512, 3, stride1=1, name='conv5')
-  return resnet.ResNet(stack_fn, True, True, 'resnet101v2', include_top,
-                       weights, input_tensor, input_shape, pooling, classes)
+
+  return resnet.ResNet(
+      stack_fn,
+      True,
+      True,
+      'resnet101v2',
+      include_top,
+      weights,
+      input_tensor,
+      input_shape,
+      pooling,
+      classes,
+      classifier_activation=classifier_activation,
+  )
 
 
 @keras_export('keras.applications.resnet_v2.ResNet152V2',
               'keras.applications.ResNet152V2')
-def ResNet152V2(include_top=True,
-                weights='imagenet',
-                input_tensor=None,
-                input_shape=None,
-                pooling=None,
-                classes=1000):
+def ResNet152V2(
+    include_top=True,
+    weights='imagenet',
+    input_tensor=None,
+    input_shape=None,
+    pooling=None,
+    classes=1000,
+    classifier_activation='softmax',
+):
   """Instantiates the ResNet152V2 architecture."""
   def stack_fn(x):
     x = resnet.stack2(x, 64, 3, name='conv2')
     x = resnet.stack2(x, 128, 8, name='conv3')
     x = resnet.stack2(x, 256, 36, name='conv4')
     return resnet.stack2(x, 512, 3, stride1=1, name='conv5')
-  return resnet.ResNet(stack_fn, True, True, 'resnet152v2', include_top,
-                       weights, input_tensor, input_shape, pooling, classes)
+
+  return resnet.ResNet(
+      stack_fn,
+      True,
+      True,
+      'resnet152v2',
+      include_top,
+      weights,
+      input_tensor,
+      input_shape,
+      pooling,
+      classes,
+      classifier_activation=classifier_activation,
+  )
 
 
 @keras_export('keras.applications.resnet_v2.preprocess_input')
@@ -123,9 +168,12 @@ DOC = """
     classes: optional number of classes to classify images
       into, only to be specified if `include_top` is True, and
       if no `weights` argument is specified.
+    classifier_activation: A `str` or callable. The activation function to use
+      on the "top" layer. Ignored unless `include_top=True`. Set
+      `classifier_activation=None` to return the logits of the "top" layer.
 
   Returns:
-    A Keras model instance.
+    A `keras.Model` instance.
 """
 
 setattr(ResNet50V2, '__doc__', ResNet50V2.__doc__ + DOC)
diff --git a/tensorflow/python/keras/applications/vgg16.py b/tensorflow/python/keras/applications/vgg16.py
index 958ed955106..e268a592833 100644
--- a/tensorflow/python/keras/applications/vgg16.py
+++ b/tensorflow/python/keras/applications/vgg16.py
@@ -37,12 +37,15 @@ WEIGHTS_PATH_NO_TOP = ('https://storage.googleapis.com/tensorflow/'
 
 
 @keras_export('keras.applications.vgg16.VGG16', 'keras.applications.VGG16')
-def VGG16(include_top=True,
-          weights='imagenet',
-          input_tensor=None,
-          input_shape=None,
-          pooling=None,
-          classes=1000):
+def VGG16(
+    include_top=True,
+    weights='imagenet',
+    input_tensor=None,
+    input_shape=None,
+    pooling=None,
+    classes=1000,
+    classifier_activation='softmax',
+):
   """Instantiates the VGG16 model.
 
   By default, it loads weights pre-trained on ImageNet. Check 'weights' for
@@ -85,13 +88,18 @@ def VGG16(include_top=True,
       classes: optional number of classes to classify images
           into, only to be specified if `include_top` is True, and
           if no `weights` argument is specified.
+      classifier_activation: A `str` or callable. The activation function to use
+          on the "top" layer. Ignored unless `include_top=True`. Set
+          `classifier_activation=None` to return the logits of the "top" layer.
 
   Returns:
-      A Keras model instance.
+    A `keras.Model` instance.
 
   Raises:
-      ValueError: in case of invalid argument for `weights`,
-          or invalid input shape.
+    ValueError: in case of invalid argument for `weights`,
+      or invalid input shape.
+    ValueError: if `classifier_activation` is not `softmax` or `None` when
+      using a pretrained top layer.
   """
   if not (weights in {'imagenet', None} or os.path.exists(weights)):
     raise ValueError('The `weights` argument should be either '
@@ -165,7 +173,10 @@ def VGG16(include_top=True,
     x = layers.Flatten(name='flatten')(x)
     x = layers.Dense(4096, activation='relu', name='fc1')(x)
     x = layers.Dense(4096, activation='relu', name='fc2')(x)
-    x = layers.Dense(classes, activation='softmax', name='predictions')(x)
+
+    imagenet_utils.validate_activation(classifier_activation, weights)
+    x = layers.Dense(classes, activation=classifier_activation,
+                     name='predictions')(x)
   else:
     if pooling == 'avg':
       x = layers.GlobalAveragePooling2D()(x)
diff --git a/tensorflow/python/keras/applications/vgg19.py b/tensorflow/python/keras/applications/vgg19.py
index 808580ada07..8d25dc0e42f 100644
--- a/tensorflow/python/keras/applications/vgg19.py
+++ b/tensorflow/python/keras/applications/vgg19.py
@@ -42,12 +42,15 @@ WEIGHTS_PATH_NO_TOP = ('https://storage.googleapis.com/tensorflow/'
 
 
 @keras_export('keras.applications.vgg19.VGG19', 'keras.applications.VGG19')
-def VGG19(include_top=True,
-          weights='imagenet',
-          input_tensor=None,
-          input_shape=None,
-          pooling=None,
-          classes=1000):
+def VGG19(
+    include_top=True,
+    weights='imagenet',
+    input_tensor=None,
+    input_shape=None,
+    pooling=None,
+    classes=1000,
+    classifier_activation='softmax',
+):
   """Instantiates the VGG19 architecture.
 
   By default, it loads weights pre-trained on ImageNet. Check 'weights' for
@@ -90,13 +93,18 @@ def VGG19(include_top=True,
     classes: optional number of classes to classify images
       into, only to be specified if `include_top` is True, and
       if no `weights` argument is specified.
+    classifier_activation: A `str` or callable. The activation function to use
+      on the "top" layer. Ignored unless `include_top=True`. Set
+      `classifier_activation=None` to return the logits of the "top" layer.
 
   Returns:
-    A Keras model instance.
+    A `keras.Model` instance.
 
   Raises:
     ValueError: in case of invalid argument for `weights`,
       or invalid input shape.
+    ValueError: if `classifier_activation` is not `softmax` or `None` when
+      using a pretrained top layer.
   """
   if not (weights in {'imagenet', None} or os.path.exists(weights)):
     raise ValueError('The `weights` argument should be either '
@@ -176,7 +184,9 @@ def VGG19(include_top=True,
     x = layers.Flatten(name='flatten')(x)
     x = layers.Dense(4096, activation='relu', name='fc1')(x)
     x = layers.Dense(4096, activation='relu', name='fc2')(x)
-    x = layers.Dense(classes, activation='softmax', name='predictions')(x)
+    imagenet_utils.validate_activation(classifier_activation, weights)
+    x = layers.Dense(classes, activation=classifier_activation,
+                     name='predictions')(x)
   else:
     if pooling == 'avg':
       x = layers.GlobalAveragePooling2D()(x)
diff --git a/tensorflow/python/keras/applications/xception.py b/tensorflow/python/keras/applications/xception.py
index 47f386cc721..7f6602b90d1 100644
--- a/tensorflow/python/keras/applications/xception.py
+++ b/tensorflow/python/keras/applications/xception.py
@@ -48,12 +48,15 @@ TF_WEIGHTS_PATH_NO_TOP = (
 
 @keras_export('keras.applications.xception.Xception',
               'keras.applications.Xception')
-def Xception(include_top=True,
-             weights='imagenet',
-             input_tensor=None,
-             input_shape=None,
-             pooling=None,
-             classes=1000):
+def Xception(
+    include_top=True,
+    weights='imagenet',
+    input_tensor=None,
+    input_shape=None,
+    pooling=None,
+    classes=1000,
+    classifier_activation='softmax',
+):
   """Instantiates the Xception architecture.
 
   Optionally loads weights pre-trained on ImageNet.
@@ -90,13 +93,18 @@ def Xception(include_top=True,
     classes: optional number of classes to classify images
       into, only to be specified if `include_top` is True,
       and if no `weights` argument is specified.
+    classifier_activation: A `str` or callable. The activation function to use
+      on the "top" layer. Ignored unless `include_top=True`. Set
+      `classifier_activation=None` to return the logits of the "top" layer.
 
   Returns:
-    A Keras model instance.
+    A `keras.Model` instance.
 
   Raises:
     ValueError: in case of invalid argument for `weights`,
       or invalid input shape.
+    ValueError: if `classifier_activation` is not `softmax` or `None` when
+      using a pretrained top layer.
   """
   if not (weights in {'imagenet', None} or os.path.exists(weights)):
     raise ValueError('The `weights` argument should be either '
@@ -260,7 +268,9 @@ def Xception(include_top=True,
 
   if include_top:
     x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
-    x = layers.Dense(classes, activation='softmax', name='predictions')(x)
+    imagenet_utils.validate_activation(classifier_activation, weights)
+    x = layers.Dense(classes, activation=classifier_activation,
+                     name='predictions')(x)
   else:
     if pooling == 'avg':
       x = layers.GlobalAveragePooling2D()(x)