Move the fit-scope check to only error our when using numpy

PiperOrigin-RevId: 228200832
This commit is contained in:
Sourabh Bajaj 2019-01-07 11:20:59 -08:00 committed by TensorFlower Gardener
parent 2c2e297e97
commit 68340a6445

View File

@ -53,14 +53,14 @@ def fit_distributed(model,
steps_per_epoch=None,
validation_steps=None):
"""Fit loop for Distribution Strategies."""
# TODO(b/122314600): Remove the scope validate.
distributed_training_utils.validate_not_in_strategy_scope()
distributed_training_utils.validate_callbacks(callbacks, model.optimizer)
distributed_training_utils.validate_inputs(
x, y, model._distribution_strategy)
first_x_value = nest.flatten(x)[0]
if isinstance(first_x_value, np.ndarray):
# TODO(b/122314600): Remove the scope validate.
distributed_training_utils.validate_not_in_strategy_scope()
steps_per_epoch, batch_size = (
distributed_training_utils.get_input_params(
model._distribution_strategy, first_x_value, steps_per_epoch,
@ -138,11 +138,11 @@ def evaluate_distributed(model,
steps=None,
callbacks=None):
"""Evaluate loop for Distribution Strategies."""
# TODO(b/122314600): Remove the scope validate.
distributed_training_utils.validate_not_in_strategy_scope()
distributed_training_utils.validate_inputs(x, y, model._distribution_strategy)
first_x_value = nest.flatten(x)[0]
if isinstance(first_x_value, np.ndarray):
# TODO(b/122314600): Remove the scope validate.
distributed_training_utils.validate_not_in_strategy_scope()
steps, batch_size = distributed_training_utils.get_input_params(
model._distribution_strategy, first_x_value, steps, batch_size)
batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
@ -175,12 +175,12 @@ def predict_distributed(model,
steps=None,
callbacks=None):
"""Predict loop for Distribution Strategies."""
# TODO(b/122314600): Remove the scope validate.
distributed_training_utils.validate_not_in_strategy_scope()
distributed_training_utils.validate_inputs(
x, None, model._distribution_strategy)
first_x_value = nest.flatten(x)[0]
if isinstance(first_x_value, np.ndarray):
# TODO(b/122314600): Remove the scope validate.
distributed_training_utils.validate_not_in_strategy_scope()
steps, batch_size = distributed_training_utils.get_input_params(
model._distribution_strategy, first_x_value, steps, batch_size)
batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)