Update model saving test with MultiWorkerMirroredStrategy.
PiperOrigin-RevId: 315235762 Change-Id: I33a1f08e415d012fd6dff8ad6ac9f97e3ed06b65
This commit is contained in:
parent
92d08074a0
commit
dab7b46024
|
@ -44,7 +44,6 @@ def mnist_synthetic_dataset(batch_size, steps_per_epoch):
|
|||
maxval=9,
|
||||
dtype=dtypes.int32)
|
||||
eval_ds = dataset_ops.Dataset.from_tensor_slices((x_test, y_test))
|
||||
eval_ds = eval_ds.repeat()
|
||||
eval_ds = eval_ds.batch(64, drop_remainder=True)
|
||||
|
||||
return train_ds, eval_ds
|
||||
|
@ -52,21 +51,19 @@ def mnist_synthetic_dataset(batch_size, steps_per_epoch):
|
|||
|
||||
def get_mnist_model(input_shape):
|
||||
"""Define a deterministically-initialized CNN model for MNIST testing."""
|
||||
model = keras.models.Sequential()
|
||||
model.add(
|
||||
keras.layers.Conv2D(
|
||||
32,
|
||||
kernel_size=(3, 3),
|
||||
activation="relu",
|
||||
input_shape=input_shape,
|
||||
kernel_initializer=keras.initializers.TruncatedNormal(seed=99)))
|
||||
model.add(keras.layers.BatchNormalization())
|
||||
model.add(keras.layers.Flatten())
|
||||
model.add(
|
||||
keras.layers.Dense(
|
||||
10,
|
||||
activation="softmax",
|
||||
kernel_initializer=keras.initializers.TruncatedNormal(seed=99)))
|
||||
inputs = keras.Input(shape=input_shape)
|
||||
x = keras.layers.Conv2D(
|
||||
32,
|
||||
kernel_size=(3, 3),
|
||||
activation="relu",
|
||||
kernel_initializer=keras.initializers.TruncatedNormal(seed=99))(inputs)
|
||||
x = keras.layers.BatchNormalization()(x)
|
||||
x = keras.layers.Flatten()(x) + keras.layers.Flatten()(x)
|
||||
x = keras.layers.Dense(
|
||||
10,
|
||||
activation="softmax",
|
||||
kernel_initializer=keras.initializers.TruncatedNormal(seed=99))(x)
|
||||
model = keras.Model(inputs=inputs, outputs=x)
|
||||
|
||||
# TODO(yuefengz): optimizer with slot variables doesn't work because of
|
||||
# optimizer's bug.
|
||||
|
|
Loading…
Reference in New Issue