diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py index 9f63f7bada5..92a46e399a9 100644 --- a/tensorflow/python/keras/engine/training_distributed.py +++ b/tensorflow/python/keras/engine/training_distributed.py @@ -53,14 +53,14 @@ def fit_distributed(model, steps_per_epoch=None, validation_steps=None): """Fit loop for Distribution Strategies.""" - # TODO(b/122314600): Remove the scope validate. - distributed_training_utils.validate_not_in_strategy_scope() distributed_training_utils.validate_callbacks(callbacks, model.optimizer) distributed_training_utils.validate_inputs( x, y, model._distribution_strategy) first_x_value = nest.flatten(x)[0] if isinstance(first_x_value, np.ndarray): + # TODO(b/122314600): Remove the scope validate. + distributed_training_utils.validate_not_in_strategy_scope() steps_per_epoch, batch_size = ( distributed_training_utils.get_input_params( model._distribution_strategy, first_x_value, steps_per_epoch, @@ -138,11 +138,11 @@ def evaluate_distributed(model, steps=None, callbacks=None): """Evaluate loop for Distribution Strategies.""" - # TODO(b/122314600): Remove the scope validate. - distributed_training_utils.validate_not_in_strategy_scope() distributed_training_utils.validate_inputs(x, y, model._distribution_strategy) first_x_value = nest.flatten(x)[0] if isinstance(first_x_value, np.ndarray): + # TODO(b/122314600): Remove the scope validate. + distributed_training_utils.validate_not_in_strategy_scope() steps, batch_size = distributed_training_utils.get_input_params( model._distribution_strategy, first_x_value, steps, batch_size) batch_size = model._validate_or_infer_batch_size(batch_size, steps, x) @@ -175,12 +175,12 @@ def predict_distributed(model, steps=None, callbacks=None): """Predict loop for Distribution Strategies.""" - # TODO(b/122314600): Remove the scope validate. - distributed_training_utils.validate_not_in_strategy_scope() distributed_training_utils.validate_inputs( x, None, model._distribution_strategy) first_x_value = nest.flatten(x)[0] if isinstance(first_x_value, np.ndarray): + # TODO(b/122314600): Remove the scope validate. + distributed_training_utils.validate_not_in_strategy_scope() steps, batch_size = distributed_training_utils.get_input_params( model._distribution_strategy, first_x_value, steps, batch_size) batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)