Simplify Keras unit tests by removing unnecessary session scopes and introducing a utility function for repeated code.
PiperOrigin-RevId: 209523944
This commit is contained in:
parent
c5f27df3a5
commit
d29759fa53
@ -235,11 +235,8 @@ class KerasCallbacksTest(test.TestCase):
|
||||
num_classes=NUM_CLASSES)
|
||||
y_test = keras.utils.to_categorical(y_test)
|
||||
y_train = keras.utils.to_categorical(y_train)
|
||||
model = keras.models.Sequential()
|
||||
model.add(
|
||||
keras.layers.Dense(
|
||||
NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
|
||||
model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
|
||||
model = testing_utils.get_small_sequential_mlp(
|
||||
num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
|
||||
model.compile(
|
||||
loss='categorical_crossentropy',
|
||||
optimizer='rmsprop',
|
||||
@ -298,9 +295,8 @@ class KerasCallbacksTest(test.TestCase):
|
||||
test_samples=50,
|
||||
input_shape=(1,),
|
||||
num_classes=NUM_CLASSES)
|
||||
model = keras.models.Sequential((keras.layers.Dense(
|
||||
1, input_dim=1, activation='relu'), keras.layers.Dense(
|
||||
1, activation='sigmoid'),))
|
||||
model = testing_utils.get_small_sequential_mlp(
|
||||
num_hidden=1, num_classes=1, input_dim=1)
|
||||
model.compile(
|
||||
optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy'])
|
||||
|
||||
@ -334,11 +330,8 @@ class KerasCallbacksTest(test.TestCase):
|
||||
num_classes=NUM_CLASSES)
|
||||
y_test = keras.utils.to_categorical(y_test)
|
||||
y_train = keras.utils.to_categorical(y_train)
|
||||
model = keras.models.Sequential()
|
||||
model.add(
|
||||
keras.layers.Dense(
|
||||
NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
|
||||
model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
|
||||
model = testing_utils.get_small_sequential_mlp(
|
||||
num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
|
||||
model.compile(
|
||||
loss='categorical_crossentropy',
|
||||
optimizer='sgd',
|
||||
@ -388,12 +381,8 @@ class KerasCallbacksTest(test.TestCase):
|
||||
def make_model():
|
||||
random_seed.set_random_seed(1234)
|
||||
np.random.seed(1337)
|
||||
model = keras.models.Sequential()
|
||||
model.add(
|
||||
keras.layers.Dense(
|
||||
NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
|
||||
model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
|
||||
|
||||
model = testing_utils.get_small_sequential_mlp(
|
||||
num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
|
||||
model.compile(
|
||||
loss='categorical_crossentropy',
|
||||
optimizer=keras.optimizers.SGD(lr=0.1),
|
||||
@ -498,12 +487,8 @@ class KerasCallbacksTest(test.TestCase):
|
||||
|
||||
def make_model():
|
||||
np.random.seed(1337)
|
||||
model = keras.models.Sequential()
|
||||
model.add(
|
||||
keras.layers.Dense(
|
||||
NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
|
||||
model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
|
||||
|
||||
model = testing_utils.get_small_sequential_mlp(
|
||||
num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
|
||||
model.compile(
|
||||
loss='categorical_crossentropy',
|
||||
optimizer=keras.optimizers.SGD(lr=0.1),
|
||||
@ -985,9 +970,8 @@ class KerasCallbacksTest(test.TestCase):
|
||||
yield x, y
|
||||
|
||||
with self.test_session():
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(10, input_dim=100, activation='relu'))
|
||||
model.add(keras.layers.Dense(10, activation='softmax'))
|
||||
model = testing_utils.get_small_sequential_mlp(
|
||||
num_hidden=10, num_classes=10, input_dim=100)
|
||||
model.compile(
|
||||
loss='categorical_crossentropy',
|
||||
optimizer='sgd',
|
||||
@ -1083,11 +1067,8 @@ class KerasCallbacksTest(test.TestCase):
|
||||
y_test = keras.utils.to_categorical(y_test)
|
||||
y_train = keras.utils.to_categorical(y_train)
|
||||
|
||||
model = keras.models.Sequential()
|
||||
model.add(
|
||||
keras.layers.Dense(
|
||||
NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
|
||||
model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
|
||||
model = testing_utils.get_small_sequential_mlp(
|
||||
num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
|
||||
model.compile(
|
||||
loss='binary_crossentropy', optimizer='sgd', metrics=['accuracy'])
|
||||
|
||||
@ -1179,7 +1160,6 @@ class KerasCallbacksTest(test.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_Tensorboard_eager(self):
|
||||
with self.test_session():
|
||||
temp_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
|
||||
self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
|
||||
|
||||
@ -1191,11 +1171,8 @@ class KerasCallbacksTest(test.TestCase):
|
||||
y_test = keras.utils.to_categorical(y_test)
|
||||
y_train = keras.utils.to_categorical(y_train)
|
||||
|
||||
model = keras.models.Sequential()
|
||||
model.add(
|
||||
keras.layers.Dense(
|
||||
NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
|
||||
model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
|
||||
model = testing_utils.get_small_sequential_mlp(
|
||||
num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
|
||||
model.compile(
|
||||
loss='binary_crossentropy',
|
||||
optimizer=adam.AdamOptimizer(0.01),
|
||||
|
@ -25,22 +25,12 @@ from tensorflow.python import keras
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import test_util as tf_test_util
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import rmsprop
|
||||
|
||||
|
||||
def _get_small_mlp(num_hidden, num_classes, input_dim=None):
|
||||
model = keras.models.Sequential()
|
||||
if input_dim:
|
||||
model.add(keras.layers.Dense(num_hidden, activation='relu',
|
||||
input_dim=input_dim))
|
||||
else:
|
||||
model.add(keras.layers.Dense(num_hidden, activation='relu'))
|
||||
model.add(keras.layers.Dense(num_classes, activation='softmax'))
|
||||
return model
|
||||
|
||||
|
||||
class TestSequential(test.TestCase, parameterized.TestCase):
|
||||
"""Most Sequential model API tests are covered in `training_test.py`.
|
||||
"""
|
||||
@ -63,7 +53,8 @@ class TestSequential(test.TestCase, parameterized.TestCase):
|
||||
batch_size = 5
|
||||
num_classes = 2
|
||||
|
||||
model = _get_small_mlp(num_hidden, num_classes, input_dim)
|
||||
model = testing_utils.get_small_sequential_mlp(
|
||||
num_hidden, num_classes, input_dim)
|
||||
model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3))
|
||||
x = np.random.random((batch_size, input_dim))
|
||||
y = np.random.random((batch_size, num_classes))
|
||||
@ -94,7 +85,7 @@ class TestSequential(test.TestCase, parameterized.TestCase):
|
||||
batch_size = 5
|
||||
num_classes = 2
|
||||
|
||||
model = _get_small_mlp(num_hidden, num_classes)
|
||||
model = testing_utils.get_small_sequential_mlp(num_hidden, num_classes)
|
||||
model.compile(
|
||||
loss='mse',
|
||||
optimizer=rmsprop.RMSPropOptimizer(1e-3),
|
||||
@ -118,7 +109,7 @@ class TestSequential(test.TestCase, parameterized.TestCase):
|
||||
num_samples = 50
|
||||
steps_per_epoch = 10
|
||||
|
||||
model = _get_small_mlp(num_hidden, num_classes)
|
||||
model = testing_utils.get_small_sequential_mlp(num_hidden, num_classes)
|
||||
model.compile(
|
||||
loss='mse',
|
||||
optimizer=rmsprop.RMSPropOptimizer(1e-3),
|
||||
@ -145,9 +136,9 @@ class TestSequential(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def get_model():
|
||||
if deferred:
|
||||
model = _get_small_mlp(10, 4)
|
||||
model = testing_utils.get_small_sequential_mlp(10, 4)
|
||||
else:
|
||||
model = _get_small_mlp(10, 4, input_dim=3)
|
||||
model = testing_utils.get_small_sequential_mlp(10, 4, input_dim=3)
|
||||
model.compile(
|
||||
optimizer=rmsprop.RMSPropOptimizer(1e-3),
|
||||
loss='categorical_crossentropy',
|
||||
@ -262,7 +253,7 @@ class TestSequential(test.TestCase, parameterized.TestCase):
|
||||
batch_size = 5
|
||||
num_classes = 2
|
||||
|
||||
model = _get_small_mlp(num_hidden, num_classes)
|
||||
model = testing_utils.get_small_sequential_mlp(num_hidden, num_classes)
|
||||
model.compile(
|
||||
loss='mse',
|
||||
optimizer=rmsprop.RMSPropOptimizer(1e-3),
|
||||
@ -284,21 +275,21 @@ class TestSequential(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_sequential_shape_inference_deferred(self):
|
||||
model = _get_small_mlp(4, 5)
|
||||
model = testing_utils.get_small_sequential_mlp(4, 5)
|
||||
output_shape = model.compute_output_shape((None, 7))
|
||||
self.assertEqual(tuple(output_shape.as_list()), (None, 5))
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_sequential_build_deferred(self):
|
||||
model = _get_small_mlp(4, 5)
|
||||
model = testing_utils.get_small_sequential_mlp(4, 5)
|
||||
|
||||
model.build((None, 10))
|
||||
self.assertTrue(model.built)
|
||||
self.assertEqual(len(model.weights), 4)
|
||||
|
||||
# Test with nested model
|
||||
model = _get_small_mlp(4, 3)
|
||||
inner_model = _get_small_mlp(4, 5)
|
||||
model = testing_utils.get_small_sequential_mlp(4, 3)
|
||||
inner_model = testing_utils.get_small_sequential_mlp(4, 5)
|
||||
model.add(inner_model)
|
||||
|
||||
model.build((None, 10))
|
||||
@ -308,8 +299,8 @@ class TestSequential(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_sequential_nesting(self):
|
||||
model = _get_small_mlp(4, 3)
|
||||
inner_model = _get_small_mlp(4, 5)
|
||||
model = testing_utils.get_small_sequential_mlp(4, 3)
|
||||
inner_model = testing_utils.get_small_sequential_mlp(4, 5)
|
||||
model.add(inner_model)
|
||||
|
||||
model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3))
|
||||
@ -353,7 +344,7 @@ class TestSequentialEagerIntegration(test.TestCase):
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_build_before_fit(self):
|
||||
# Fix for b/112433577
|
||||
model = _get_small_mlp(4, 5)
|
||||
model = testing_utils.get_small_sequential_mlp(4, 5)
|
||||
model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3))
|
||||
|
||||
model.build((None, 6))
|
||||
|
@ -49,7 +49,6 @@ class TrainingTest(test.TestCase):
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_fit_on_arrays(self):
|
||||
with self.test_session():
|
||||
a = keras.layers.Input(shape=(3,), name='input_a')
|
||||
b = keras.layers.Input(shape=(3,), name='input_b')
|
||||
|
||||
@ -253,7 +252,6 @@ class TrainingTest(test.TestCase):
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_evaluate_predict_on_arrays(self):
|
||||
with self.test_session():
|
||||
a = keras.layers.Input(shape=(3,), name='input_a')
|
||||
b = keras.layers.Input(shape=(3,), name='input_b')
|
||||
|
||||
@ -340,12 +338,8 @@ class TrainingTest(test.TestCase):
|
||||
test_samples = 1000
|
||||
input_dim = 5
|
||||
|
||||
with self.test_session():
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(10, input_shape=(input_dim,)))
|
||||
model.add(keras.layers.Activation('relu'))
|
||||
model.add(keras.layers.Dense(num_classes))
|
||||
model.add(keras.layers.Activation('softmax'))
|
||||
model = testing_utils.get_small_sequential_mlp(
|
||||
num_hidden=10, num_classes=num_classes, input_dim=input_dim)
|
||||
optimizer = RMSPropOptimizer(learning_rate=0.001)
|
||||
model.compile(optimizer, loss='categorical_crossentropy')
|
||||
np.random.seed(1337)
|
||||
@ -468,12 +462,8 @@ class LossWeightingTest(test.TestCase):
|
||||
input_dim = 5
|
||||
learning_rate = 0.001
|
||||
|
||||
with self.test_session():
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(10, input_shape=(input_dim,)))
|
||||
model.add(keras.layers.Activation('relu'))
|
||||
model.add(keras.layers.Dense(num_classes))
|
||||
model.add(keras.layers.Activation('softmax'))
|
||||
model = testing_utils.get_small_sequential_mlp(
|
||||
num_hidden=10, num_classes=num_classes, input_dim=input_dim)
|
||||
model.compile(
|
||||
loss='categorical_crossentropy',
|
||||
metrics=['acc'],
|
||||
@ -541,12 +531,8 @@ class LossWeightingTest(test.TestCase):
|
||||
input_dim = 5
|
||||
learning_rate = 0.001
|
||||
|
||||
with self.test_session():
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(10, input_shape=(input_dim,)))
|
||||
model.add(keras.layers.Activation('relu'))
|
||||
model.add(keras.layers.Dense(num_classes))
|
||||
model.add(keras.layers.Activation('softmax'))
|
||||
model = testing_utils.get_small_sequential_mlp(
|
||||
num_hidden=10, num_classes=num_classes, input_dim=input_dim)
|
||||
model.compile(
|
||||
RMSPropOptimizer(learning_rate=learning_rate),
|
||||
metrics=['acc'],
|
||||
@ -1909,11 +1895,7 @@ class TestTrainingWithDatasetIterators(test.TestCase):
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_training_and_eval_methods_on_iterators_single_io(self):
|
||||
with self.test_session():
|
||||
x = keras.layers.Input(shape=(3,), name='input')
|
||||
y = keras.layers.Dense(4, name='dense')(x)
|
||||
model = keras.Model(x, y)
|
||||
|
||||
model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
|
||||
optimizer = RMSPropOptimizer(learning_rate=0.001)
|
||||
loss = 'mse'
|
||||
metrics = ['mae', metrics_module.CategoricalAccuracy()]
|
||||
@ -1975,11 +1957,7 @@ class TestTrainingWithDatasetIterators(test.TestCase):
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_get_next_op_created_once(self):
|
||||
with self.test_session():
|
||||
x = keras.layers.Input(shape=(3,), name='input')
|
||||
y = keras.layers.Dense(4, name='dense')(x)
|
||||
model = keras.Model(x, y)
|
||||
|
||||
model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
|
||||
optimizer = RMSPropOptimizer(learning_rate=0.001)
|
||||
loss = 'mse'
|
||||
metrics = ['mae']
|
||||
@ -2000,11 +1978,7 @@ class TestTrainingWithDatasetIterators(test.TestCase):
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_iterators_running_out_of_data(self):
|
||||
with self.test_session():
|
||||
x = keras.layers.Input(shape=(3,), name='input')
|
||||
y = keras.layers.Dense(4, name='dense')(x)
|
||||
model = keras.Model(x, y)
|
||||
|
||||
model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
|
||||
optimizer = RMSPropOptimizer(learning_rate=0.001)
|
||||
loss = 'mse'
|
||||
metrics = ['mae']
|
||||
@ -2028,11 +2002,7 @@ class TestTrainingWithDataset(test.TestCase):
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_calling_model_on_same_dataset(self):
|
||||
with self.test_session():
|
||||
x = keras.layers.Input(shape=(3,), name='input')
|
||||
y = keras.layers.Dense(4, name='dense')(x)
|
||||
model = keras.Model(x, y)
|
||||
|
||||
model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
|
||||
optimizer = RMSPropOptimizer(learning_rate=0.001)
|
||||
loss = 'mse'
|
||||
metrics = ['mae']
|
||||
@ -2055,11 +2025,7 @@ class TestTrainingWithDataset(test.TestCase):
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_training_and_eval_methods_on_dataset(self):
|
||||
with self.test_session():
|
||||
x = keras.layers.Input(shape=(3,), name='input')
|
||||
y = keras.layers.Dense(4, name='dense')(x)
|
||||
model = keras.Model(x, y)
|
||||
|
||||
model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
|
||||
optimizer = RMSPropOptimizer(learning_rate=0.001)
|
||||
loss = 'mse'
|
||||
metrics = ['mae', metrics_module.CategoricalAccuracy()]
|
||||
@ -2119,13 +2085,8 @@ class TestTrainingWithDataset(test.TestCase):
|
||||
|
||||
def test_dataset_input_shape_validation(self):
|
||||
with self.test_session():
|
||||
x = keras.layers.Input(shape=(3,), name='input')
|
||||
y = keras.layers.Dense(4, name='dense')(x)
|
||||
model = keras.Model(x, y)
|
||||
|
||||
optimizer = RMSPropOptimizer(learning_rate=0.001)
|
||||
loss = 'mse'
|
||||
model.compile(optimizer, loss)
|
||||
model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
|
||||
model.compile(optimizer=RMSPropOptimizer(learning_rate=0.001), loss='mse')
|
||||
|
||||
# User forgets to batch the dataset
|
||||
inputs = np.zeros((10, 3))
|
||||
@ -2134,7 +2095,7 @@ class TestTrainingWithDataset(test.TestCase):
|
||||
dataset = dataset.repeat(100)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'expected input to have 2 dimensions'):
|
||||
r'expected (.*?) to have 2 dimensions'):
|
||||
model.train_on_batch(dataset)
|
||||
|
||||
# Wrong input shape
|
||||
@ -2145,7 +2106,7 @@ class TestTrainingWithDataset(test.TestCase):
|
||||
dataset = dataset.batch(10)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'expected input to have shape'):
|
||||
r'expected (.*?) to have shape \(3,\)'):
|
||||
model.train_on_batch(dataset)
|
||||
|
||||
|
||||
@ -2176,7 +2137,6 @@ class TestTrainingWithMetrics(test.TestCase):
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_metrics_correctness(self):
|
||||
with self.test_session():
|
||||
model = keras.Sequential()
|
||||
model.add(
|
||||
keras.layers.Dense(
|
||||
@ -2203,7 +2163,6 @@ class TestTrainingWithMetrics(test.TestCase):
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_metrics_correctness_with_iterator(self):
|
||||
with self.test_session():
|
||||
model = keras.Sequential()
|
||||
model.add(
|
||||
keras.layers.Dense(
|
||||
@ -2237,7 +2196,6 @@ class TestTrainingWithMetrics(test.TestCase):
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_metrics_correctness_with_weighted_metrics(self):
|
||||
with self.test_session():
|
||||
np.random.seed(1337)
|
||||
x = np.array([[[1.], [1.]], [[0.], [0.]]])
|
||||
model = keras.models.Sequential()
|
||||
@ -2266,7 +2224,6 @@ class TestTrainingWithMetrics(test.TestCase):
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_metric_state_reset_between_fit_and_evaluate(self):
|
||||
with self.test_session():
|
||||
model = keras.Sequential()
|
||||
model.add(keras.layers.Dense(3, activation='relu', input_dim=4))
|
||||
model.add(keras.layers.Dense(1, activation='sigmoid'))
|
||||
@ -2291,11 +2248,8 @@ class TestTrainingWithMetrics(test.TestCase):
|
||||
num_classes = 5
|
||||
input_dim = 5
|
||||
|
||||
with self.test_session():
|
||||
model = keras.models.Sequential()
|
||||
model.add(
|
||||
keras.layers.Dense(10, activation='relu', input_shape=(input_dim,)))
|
||||
model.add(keras.layers.Dense(num_classes, activation='softmax'))
|
||||
model = testing_utils.get_small_sequential_mlp(
|
||||
num_hidden=10, num_classes=num_classes, input_dim=input_dim)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError, 'Type of `metrics` argument not understood. '
|
||||
|
@ -184,3 +184,22 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
|
||||
# for further checks in the caller function
|
||||
return actual_output
|
||||
|
||||
|
||||
def get_small_sequential_mlp(num_hidden, num_classes, input_dim=None):
|
||||
model = keras.models.Sequential()
|
||||
if input_dim:
|
||||
model.add(keras.layers.Dense(num_hidden, activation='relu',
|
||||
input_dim=input_dim))
|
||||
else:
|
||||
model.add(keras.layers.Dense(num_hidden, activation='relu'))
|
||||
activation = 'sigmoid' if num_classes == 1 else 'softmax'
|
||||
model.add(keras.layers.Dense(num_classes, activation=activation))
|
||||
return model
|
||||
|
||||
|
||||
def get_small_functional_mlp(num_hidden, num_classes, input_dim):
|
||||
inputs = keras.Input(shape=(input_dim,))
|
||||
outputs = keras.layers.Dense(num_hidden, activation='relu')(inputs)
|
||||
activation = 'sigmoid' if num_classes == 1 else 'softmax'
|
||||
outputs = keras.layers.Dense(num_classes, activation=activation)(outputs)
|
||||
return keras.Model(inputs, outputs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user