Move the fit-scope check to only error our when using numpy
PiperOrigin-RevId: 228200832
This commit is contained in:
parent
2c2e297e97
commit
68340a6445
@ -53,14 +53,14 @@ def fit_distributed(model,
|
||||
steps_per_epoch=None,
|
||||
validation_steps=None):
|
||||
"""Fit loop for Distribution Strategies."""
|
||||
# TODO(b/122314600): Remove the scope validate.
|
||||
distributed_training_utils.validate_not_in_strategy_scope()
|
||||
distributed_training_utils.validate_callbacks(callbacks, model.optimizer)
|
||||
distributed_training_utils.validate_inputs(
|
||||
x, y, model._distribution_strategy)
|
||||
|
||||
first_x_value = nest.flatten(x)[0]
|
||||
if isinstance(first_x_value, np.ndarray):
|
||||
# TODO(b/122314600): Remove the scope validate.
|
||||
distributed_training_utils.validate_not_in_strategy_scope()
|
||||
steps_per_epoch, batch_size = (
|
||||
distributed_training_utils.get_input_params(
|
||||
model._distribution_strategy, first_x_value, steps_per_epoch,
|
||||
@ -138,11 +138,11 @@ def evaluate_distributed(model,
|
||||
steps=None,
|
||||
callbacks=None):
|
||||
"""Evaluate loop for Distribution Strategies."""
|
||||
# TODO(b/122314600): Remove the scope validate.
|
||||
distributed_training_utils.validate_not_in_strategy_scope()
|
||||
distributed_training_utils.validate_inputs(x, y, model._distribution_strategy)
|
||||
first_x_value = nest.flatten(x)[0]
|
||||
if isinstance(first_x_value, np.ndarray):
|
||||
# TODO(b/122314600): Remove the scope validate.
|
||||
distributed_training_utils.validate_not_in_strategy_scope()
|
||||
steps, batch_size = distributed_training_utils.get_input_params(
|
||||
model._distribution_strategy, first_x_value, steps, batch_size)
|
||||
batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
|
||||
@ -175,12 +175,12 @@ def predict_distributed(model,
|
||||
steps=None,
|
||||
callbacks=None):
|
||||
"""Predict loop for Distribution Strategies."""
|
||||
# TODO(b/122314600): Remove the scope validate.
|
||||
distributed_training_utils.validate_not_in_strategy_scope()
|
||||
distributed_training_utils.validate_inputs(
|
||||
x, None, model._distribution_strategy)
|
||||
first_x_value = nest.flatten(x)[0]
|
||||
if isinstance(first_x_value, np.ndarray):
|
||||
# TODO(b/122314600): Remove the scope validate.
|
||||
distributed_training_utils.validate_not_in_strategy_scope()
|
||||
steps, batch_size = distributed_training_utils.get_input_params(
|
||||
model._distribution_strategy, first_x_value, steps, batch_size)
|
||||
batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user