Update model saving test with MultiWorkerMirroredStrategy.

PiperOrigin-RevId: 315235762
Change-Id: I33a1f08e415d012fd6dff8ad6ac9f97e3ed06b65
Priya Gupta authored 2020-06-08 02:19:00 -07:00; committed by TensorFlower Gardener
parent 92d08074a0
commit dab7b46024
1 changed file with 13 additions and 16 deletions
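The diff below touches two test helpers, mnist_synthetic_dataset and get_mnist_model. As a minimal sketch of the kind of test the title refers to (my reconstruction, not the commit's test code: the strategy setup, compile arguments, and save path are all assumptions; only the two helper names come from this file), a model-saving run under MultiWorkerMirroredStrategy could look like:

import tensorflow as tf

# Assumed: get_mnist_model and mnist_synthetic_dataset are importable from
# the utilities file changed below.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
with strategy.scope():
  # Model variables must be created inside the strategy's scope.
  model = get_mnist_model(input_shape=(28, 28, 1))
  model.compile(
      loss="sparse_categorical_crossentropy",
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
      metrics=["accuracy"])

train_ds, eval_ds = mnist_synthetic_dataset(batch_size=64, steps_per_epoch=2)
model.fit(train_ds, epochs=1, steps_per_epoch=2)
model.evaluate(eval_ds)

# In a real multi-worker test, each worker saves to its own path and only
# the chief's SavedModel is typically kept.
model.save("/tmp/mnist_saved_model")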

@@ -44,7 +44,6 @@ def mnist_synthetic_dataset(batch_size, steps_per_epoch):
       maxval=9,
       dtype=dtypes.int32)
   eval_ds = dataset_ops.Dataset.from_tensor_slices((x_test, y_test))
-  eval_ds = eval_ds.repeat()
   eval_ds = eval_ds.batch(64, drop_remainder=True)
   return train_ds, eval_ds
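A side note on the hunk above (my inference; the commit message doesn't say): removing repeat() leaves eval_ds finite, so model.evaluate(eval_ds) can run one full pass and infer its own step count, whereas an infinite dataset needs an explicit steps argument. A minimal illustration of the difference:

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(
    (tf.zeros([64, 28, 28, 1]), tf.zeros([64], tf.int32)))
finite = ds.batch(64, drop_remainder=True)
infinite = ds.repeat().batch(64, drop_remainder=True)

# model.evaluate(finite)             # OK: Keras infers one full pass
# model.evaluate(infinite)           # error: infinite data needs steps=
# model.evaluate(infinite, steps=1)  # OK with an explicit step count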
@@ -52,21 +51,19 @@ def mnist_synthetic_dataset(batch_size, steps_per_epoch):
 def get_mnist_model(input_shape):
   """Define a deterministically-initialized CNN model for MNIST testing."""
-  model = keras.models.Sequential()
-  model.add(
-      keras.layers.Conv2D(
-          32,
-          kernel_size=(3, 3),
-          activation="relu",
-          input_shape=input_shape,
-          kernel_initializer=keras.initializers.TruncatedNormal(seed=99)))
-  model.add(keras.layers.BatchNormalization())
-  model.add(keras.layers.Flatten())
-  model.add(
-      keras.layers.Dense(
-          10,
-          activation="softmax",
-          kernel_initializer=keras.initializers.TruncatedNormal(seed=99)))
+  inputs = keras.Input(shape=input_shape)
+  x = keras.layers.Conv2D(
+      32,
+      kernel_size=(3, 3),
+      activation="relu",
+      kernel_initializer=keras.initializers.TruncatedNormal(seed=99))(inputs)
+  x = keras.layers.BatchNormalization()(x)
+  x = keras.layers.Flatten()(x) + keras.layers.Flatten()(x)
+  x = keras.layers.Dense(
+      10,
+      activation="softmax",
+      kernel_initializer=keras.initializers.TruncatedNormal(seed=99))(x)
+  model = keras.Model(inputs=inputs, outputs=x)
   # TODO(yuefengz): optimizer with slot variables doesn't work because of
   # optimizer's bug.
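One detail worth noting in the new function body: keras.layers.Flatten()(x) + keras.layers.Flatten()(x) sums two flattens of the same tensor, which embeds a raw TensorFlow add op (not a Keras layer) in the functional graph; a plausible reading is that this exercises saving of models containing such ops. A small save/load round-trip check for the rebuilt model (my sketch, not part of the commit; the path is an assumption):

import numpy as np
import tensorflow as tf

model = get_mnist_model(input_shape=(28, 28, 1))
x = np.random.uniform(size=(4, 28, 28, 1)).astype("float32")

model.save("/tmp/mnist_functional")
restored = tf.keras.models.load_model("/tmp/mnist_functional")

# The seeded initializers (seed=99) and identical graph mean the restored
# model should reproduce the original predictions.
np.testing.assert_allclose(model.predict(x), restored.predict(x))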