From dab7b46024395609e1dff662495fa58949a7d9b6 Mon Sep 17 00:00:00 2001 From: Priya Gupta Date: Mon, 8 Jun 2020 02:19:00 -0700 Subject: [PATCH] Update model saving test with MultiWorkerMirroredStrategy. PiperOrigin-RevId: 315235762 Change-Id: I33a1f08e415d012fd6dff8ad6ac9f97e3ed06b65 --- .../distribute/multi_worker_testing_utils.py | 29 +++++++++---------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/tensorflow/python/keras/distribute/multi_worker_testing_utils.py b/tensorflow/python/keras/distribute/multi_worker_testing_utils.py index 57e279331ae..4231c32d22e 100644 --- a/tensorflow/python/keras/distribute/multi_worker_testing_utils.py +++ b/tensorflow/python/keras/distribute/multi_worker_testing_utils.py @@ -44,7 +44,6 @@ def mnist_synthetic_dataset(batch_size, steps_per_epoch): maxval=9, dtype=dtypes.int32) eval_ds = dataset_ops.Dataset.from_tensor_slices((x_test, y_test)) - eval_ds = eval_ds.repeat() eval_ds = eval_ds.batch(64, drop_remainder=True) return train_ds, eval_ds @@ -52,21 +51,19 @@ def mnist_synthetic_dataset(batch_size, steps_per_epoch): def get_mnist_model(input_shape): """Define a deterministically-initialized CNN model for MNIST testing.""" - model = keras.models.Sequential() - model.add( - keras.layers.Conv2D( - 32, - kernel_size=(3, 3), - activation="relu", - input_shape=input_shape, - kernel_initializer=keras.initializers.TruncatedNormal(seed=99))) - model.add(keras.layers.BatchNormalization()) - model.add(keras.layers.Flatten()) - model.add( - keras.layers.Dense( - 10, - activation="softmax", - kernel_initializer=keras.initializers.TruncatedNormal(seed=99))) + inputs = keras.Input(shape=input_shape) + x = keras.layers.Conv2D( + 32, + kernel_size=(3, 3), + activation="relu", + kernel_initializer=keras.initializers.TruncatedNormal(seed=99))(inputs) + x = keras.layers.BatchNormalization()(x) + x = keras.layers.Flatten()(x) + keras.layers.Flatten()(x) + x = keras.layers.Dense( + 10, + activation="softmax", + kernel_initializer=keras.initializers.TruncatedNormal(seed=99))(x) + model = keras.Model(inputs=inputs, outputs=x) # TODO(yuefengz): optimizer with slot variables doesn't work because of # optimizer's bug.