diff --git a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_distribution_test.py
index 0b93c1d57c6..7fc2b42c919 100644
--- a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_distribution_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_distribution_test.py
@@ -40,9 +40,10 @@ class ImagePreprocessingDistributionTest(
     preprocessing_test_utils.PreprocessingLayerTest):
 
   def test_distribution(self, distribution):
-    np_images = np.random.random((1000, 32, 32, 3)).astype(np.float32)
+    # TODO(b/159738418): large image input causes OOM in ubuntu multi gpu.
+    np_images = np.random.random((32, 32, 32, 3)).astype(np.float32)
     image_dataset = dataset_ops.Dataset.from_tensor_slices(np_images).batch(
-        32, drop_remainder=True)
+        16, drop_remainder=True)
 
     with distribution.scope():
       input_data = keras.Input(shape=(32, 32, 3), dtype=dtypes.float32)
@@ -58,7 +59,7 @@ class ImagePreprocessingDistributionTest(
       output = flatten_layer(preprocessed_image)
       cls_layer = keras.layers.Dense(units=1, activation="sigmoid")
       output = cls_layer(output)
-      model = keras.Model(inputs=input_data, outputs=preprocessed_image)
+      model = keras.Model(inputs=input_data, outputs=output)
     model.compile(loss="binary_crossentropy")
     _ = model.predict(image_dataset)