Update model saving test with MultiWorkerMirroredStrategy.

PiperOrigin-RevId: 315235762
Change-Id: I33a1f08e415d012fd6dff8ad6ac9f97e3ed06b65
This commit is contained in:
Priya Gupta 2020-06-08 02:19:00 -07:00 committed by TensorFlower Gardener
parent 92d08074a0
commit dab7b46024
1 changed files with 13 additions and 16 deletions

View File

@ -44,7 +44,6 @@ def mnist_synthetic_dataset(batch_size, steps_per_epoch):
maxval=9,
dtype=dtypes.int32)
eval_ds = dataset_ops.Dataset.from_tensor_slices((x_test, y_test))
eval_ds = eval_ds.repeat()
eval_ds = eval_ds.batch(64, drop_remainder=True)
return train_ds, eval_ds
@ -52,21 +51,19 @@ def mnist_synthetic_dataset(batch_size, steps_per_epoch):
def get_mnist_model(input_shape):
"""Define a deterministically-initialized CNN model for MNIST testing."""
model = keras.models.Sequential()
model.add(
keras.layers.Conv2D(
32,
kernel_size=(3, 3),
activation="relu",
input_shape=input_shape,
kernel_initializer=keras.initializers.TruncatedNormal(seed=99)))
model.add(keras.layers.BatchNormalization())
model.add(keras.layers.Flatten())
model.add(
keras.layers.Dense(
10,
activation="softmax",
kernel_initializer=keras.initializers.TruncatedNormal(seed=99)))
inputs = keras.Input(shape=input_shape)
x = keras.layers.Conv2D(
32,
kernel_size=(3, 3),
activation="relu",
kernel_initializer=keras.initializers.TruncatedNormal(seed=99))(inputs)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Flatten()(x) + keras.layers.Flatten()(x)
x = keras.layers.Dense(
10,
activation="softmax",
kernel_initializer=keras.initializers.TruncatedNormal(seed=99))(x)
model = keras.Model(inputs=inputs, outputs=x)
# TODO(yuefengz): optimizer with slot variables doesn't work because of
# optimizer's bug.