Update model saving test with MultiWorkerMirroredStrategy.
PiperOrigin-RevId: 315235762 Change-Id: I33a1f08e415d012fd6dff8ad6ac9f97e3ed06b65
This commit is contained in:
parent
92d08074a0
commit
dab7b46024
|
@ -44,7 +44,6 @@ def mnist_synthetic_dataset(batch_size, steps_per_epoch):
|
||||||
maxval=9,
|
maxval=9,
|
||||||
dtype=dtypes.int32)
|
dtype=dtypes.int32)
|
||||||
eval_ds = dataset_ops.Dataset.from_tensor_slices((x_test, y_test))
|
eval_ds = dataset_ops.Dataset.from_tensor_slices((x_test, y_test))
|
||||||
eval_ds = eval_ds.repeat()
|
|
||||||
eval_ds = eval_ds.batch(64, drop_remainder=True)
|
eval_ds = eval_ds.batch(64, drop_remainder=True)
|
||||||
|
|
||||||
return train_ds, eval_ds
|
return train_ds, eval_ds
|
||||||
|
@ -52,21 +51,19 @@ def mnist_synthetic_dataset(batch_size, steps_per_epoch):
|
||||||
|
|
||||||
def get_mnist_model(input_shape):
|
def get_mnist_model(input_shape):
|
||||||
"""Define a deterministically-initialized CNN model for MNIST testing."""
|
"""Define a deterministically-initialized CNN model for MNIST testing."""
|
||||||
model = keras.models.Sequential()
|
inputs = keras.Input(shape=input_shape)
|
||||||
model.add(
|
x = keras.layers.Conv2D(
|
||||||
keras.layers.Conv2D(
|
|
||||||
32,
|
32,
|
||||||
kernel_size=(3, 3),
|
kernel_size=(3, 3),
|
||||||
activation="relu",
|
activation="relu",
|
||||||
input_shape=input_shape,
|
kernel_initializer=keras.initializers.TruncatedNormal(seed=99))(inputs)
|
||||||
kernel_initializer=keras.initializers.TruncatedNormal(seed=99)))
|
x = keras.layers.BatchNormalization()(x)
|
||||||
model.add(keras.layers.BatchNormalization())
|
x = keras.layers.Flatten()(x) + keras.layers.Flatten()(x)
|
||||||
model.add(keras.layers.Flatten())
|
x = keras.layers.Dense(
|
||||||
model.add(
|
|
||||||
keras.layers.Dense(
|
|
||||||
10,
|
10,
|
||||||
activation="softmax",
|
activation="softmax",
|
||||||
kernel_initializer=keras.initializers.TruncatedNormal(seed=99)))
|
kernel_initializer=keras.initializers.TruncatedNormal(seed=99))(x)
|
||||||
|
model = keras.Model(inputs=inputs, outputs=x)
|
||||||
|
|
||||||
# TODO(yuefengz): optimizer with slot variables doesn't work because of
|
# TODO(yuefengz): optimizer with slot variables doesn't work because of
|
||||||
# optimizer's bug.
|
# optimizer's bug.
|
||||||
|
|
Loading…
Reference in New Issue