diff --git a/tensorflow/python/keras/distribute/custom_training_loop_models_test.py b/tensorflow/python/keras/distribute/custom_training_loop_models_test.py index 5a9384bb7e0..b680960429c 100644 --- a/tensorflow/python/keras/distribute/custom_training_loop_models_test.py +++ b/tensorflow/python/keras/distribute/custom_training_loop_models_test.py @@ -251,6 +251,33 @@ class KerasModelsTest(test.TestCase, parameterized.TestCase): train_step(input_iterator) + @combinations.generate( + combinations.combine( + distribution=strategy_combinations.all_strategies, + mode=["eager"])) + def test_model_predict_with_dynamic_batch(self, distribution): + input_data = np.random.random([1, 32, 64, 64, 3]) + input_shape = tuple(input_data.shape[1:]) + + def build_model(): + model = keras.models.Sequential() + model.add( + keras.layers.ConvLSTM2D( + 4, + kernel_size=(4, 4), + activation="sigmoid", + padding="same", + input_shape=input_shape)) + model.add(keras.layers.GlobalMaxPooling2D()) + model.add(keras.layers.Dense(2, activation="sigmoid")) + return model + + with distribution.scope(): + model = build_model() + model.compile(loss="binary_crossentropy", optimizer="adam") + result = model.predict(input_data) + self.assertEqual(result.shape, (1, 2)) + @combinations.generate( combinations.combine( distribution=strategy_combinations.all_strategies,