Move v1-specific training_utils to training_utils_v1.py
PiperOrigin-RevId: 333612671
Change-Id: I63e2058f4c2053fd5d4ec1bcca0ad4810ddcb136

parent df2faf57a0
commit 949f8ea44a
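The change is a module split: helpers that are only used by the v1 (graph-mode) training paths move out of training_utils into a new training_utils_v1 module, and callers update their imports to match. A minimal sketch of that caller-side update, assuming the helper standardize_input_data and the import paths that appear in the hunks below; the wrapper function itself is hypothetical:

# Hypothetical wrapper illustrating the import move described by the commit
# message; `model` stands in for an existing v1 Keras model and `data` for its
# input arrays.
def standardize(model, data):
  # Before: v1-only helpers were reached through training_utils.
  #   from tensorflow.python.keras.engine import training_utils
  #   return training_utils.standardize_input_data(data, model.input_names)
  # After: the same helpers live in training_utils_v1.
  from tensorflow.python.keras.engine import training_utils_v1  # pylint: disable=g-import-not-at-top
  return training_utils_v1.standardize_input_data(data, model.input_names)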
@@ -245,8 +245,8 @@ class TensorBoard(callbacks.TensorBoard):
 # visualize embeddings.
 if self.embeddings_freq and self.embeddings_data is not None:
 # Avoid circular dependency.
-from tensorflow.python.keras.engine import training_utils # pylint: disable=g-import-not-at-top
-self.embeddings_data = training_utils.standardize_input_data(
+from tensorflow.python.keras.engine import training_utils_v1 # pylint: disable=g-import-not-at-top
+self.embeddings_data = training_utils_v1.standardize_input_data(
 self.embeddings_data, model.input_names)

 # If embedding_layer_names are not provided, get all of the embedding
@@ -36,9 +36,9 @@ from tensorflow.python.framework import tensor_util
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import callbacks
 from tensorflow.python.keras import metrics as metrics_module
-from tensorflow.python.keras import optimizers
+from tensorflow.python.keras import optimizer_v1
 from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils
-from tensorflow.python.keras.engine import training_utils
+from tensorflow.python.keras.engine import training_utils_v1
 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.keras.utils import tf_contextlib
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
@@ -639,9 +639,9 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
 # TODO(b/124535720): Remove once this standarize data logic is shared with
 # main flow.
 inputs, targets = nest.map_structure(
-training_utils.standardize_single_array, (inputs, targets))
+training_utils_v1.standardize_single_array, (inputs, targets))
 else:
-inputs = training_utils.ModelInputs(inputs).as_list()
+inputs = training_utils_v1.ModelInputs(inputs).as_list()

 if mode == ModeKeys.PREDICT:
 sample_weights = []
@@ -779,7 +779,7 @@ def _clone_and_build_model(model, mode, inputs=None, targets=None):
 cloned_model = models.clone_model(model, input_tensors=inputs)

 # Compile and build model.
-if isinstance(model.optimizer, optimizers.TFOptimizer):
+if isinstance(model.optimizer, optimizer_v1.TFOptimizer):
 optimizer = model.optimizer
 else:
 optimizer_config = model.optimizer.get_config()
@@ -38,7 +38,6 @@ py_library(
 "training_eager_v1.py",
 "training_generator_v1.py",
 "training_utils.py",
+"training_utils_v1.py",
 "training_v1.py",
 ],
 srcs_version = "PY2AND3",
@@ -542,9 +541,9 @@ tf_py_test(
 )

 tf_py_test(
-name = "training_utils_test",
+name = "training_utils_v1_test",
 size = "medium",
-srcs = ["training_utils_test.py"],
+srcs = ["training_utils_v1_test.py"],
 python_version = "PY3",
 tags = [
 "no_oss", # TODO(b/135021748) reenable
@@ -30,7 +30,7 @@ from tensorflow.python.framework import errors
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import callbacks as cbks
 from tensorflow.python.keras.distribute import distributed_training_utils_v1
-from tensorflow.python.keras.engine import training_utils
+from tensorflow.python.keras.engine import training_utils_v1
 from tensorflow.python.keras.utils.generic_utils import make_batches
 from tensorflow.python.keras.utils.generic_utils import slice_arrays
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
@@ -139,7 +139,7 @@ def model_iteration(model,
 if is_dataset:
 if steps_per_epoch is None:
 reset_dataset_after_each_epoch = True
-steps_per_epoch = training_utils.infer_steps_for_dataset(
+steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
 model, inputs, steps_per_epoch, epochs=epochs, steps_name=steps_name)
 input_iterator = _get_iterator(inputs, model._distribution_strategy)

@@ -154,7 +154,7 @@ def model_iteration(model,
 do_validation = val_inputs is not None

 # Convert Eager Tensors to NumPy arrays to support batching/shuffling.
-inputs, targets, sample_weights = training_utils. \
+inputs, targets, sample_weights = training_utils_v1. \
 convert_eager_tensors_to_numpy((inputs, targets, sample_weights))

 # Prepare input data.
@@ -197,7 +197,7 @@ def model_iteration(model,
 # model_iteration() call, it will not trigger the dataset-input path
 # that determines the number of steps required. To avoid this issue,
 # set validation_steps here if validation_steps is None.
-validation_steps = training_utils.infer_steps_for_dataset(
+validation_steps = training_utils_v1.infer_steps_for_dataset(
 model,
 val_inputs,
 validation_steps,
@@ -240,12 +240,12 @@ def model_iteration(model,

 # Select aggregation method.
 if mode == ModeKeys.PREDICT:
-aggregator = training_utils.OutputsAggregator(
+aggregator = training_utils_v1.OutputsAggregator(
 use_steps,
 num_samples=None if steps_per_epoch else num_samples_or_steps,
 steps=steps_per_epoch)
 else:
-aggregator = training_utils.MetricsAggregator(
+aggregator = training_utils_v1.MetricsAggregator(
 use_steps,
 num_samples=None if steps_per_epoch else num_samples_or_steps,
 steps=steps_per_epoch)
@@ -350,7 +350,7 @@ def model_iteration(model,
 # Sample-wise loop.
 index_array = np.arange(num_samples_or_steps)
 if shuffle == 'batch':
-index_array = training_utils.batch_shuffle(index_array, batch_size)
+index_array = training_utils_v1.batch_shuffle(index_array, batch_size)
 elif shuffle:
 np.random.shuffle(index_array)
 batches = make_batches(num_samples_or_steps, batch_size)
@@ -409,7 +409,7 @@ def model_iteration(model,

 # Run the test loop every `validation_freq` epochs during training.
 if (do_validation and
-training_utils.should_run_validation(validation_freq, epoch) and
+training_utils_v1.should_run_validation(validation_freq, epoch) and
 not callbacks.model.stop_training):

 if model._compile_distribution:
@@ -483,8 +483,8 @@ def _get_num_samples_or_steps(ins, batch_size, steps_per_epoch):
 """Returns total number of samples (when training in batch mode) or steps."""
 if steps_per_epoch:
 return steps_per_epoch
-return training_utils.check_num_samples(ins, batch_size, steps_per_epoch,
+return training_utils_v1.check_num_samples(ins, batch_size, steps_per_epoch,
 'steps_per_epoch')


 def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
@@ -527,7 +527,7 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
 inputs,
 extract_tensors_from_dataset=True)

-inputs = training_utils.ModelInputs(inputs).as_list()
+inputs = training_utils_v1.ModelInputs(inputs).as_list()
 targets = list(targets or [])
 sample_weights = list(sample_weights or [])
 ins = inputs + targets + sample_weights
@@ -541,7 +541,7 @@ def _get_iterator(inputs, distribution_strategy=None):
 if distribution_strategy:
 return distributed_training_utils_v1.get_iterator(
 inputs, distribution_strategy)
-return training_utils.get_iterator(inputs)
+return training_utils_v1.get_iterator(inputs)


 def _reinitialize_iterator(iterator, distribution_strategy=None):
@@ -549,7 +549,7 @@ def _reinitialize_iterator(iterator, distribution_strategy=None):
 distributed_training_utils_v1.initialize_iterator(
 iterator, distribution_strategy)
 else:
-training_utils.initialize_iterator(iterator)
+training_utils_v1.initialize_iterator(iterator)


 def _make_execution_function(model, mode):
@@ -593,7 +593,7 @@ predict_loop = functools.partial(
 model_iteration, mode=ModeKeys.PREDICT, shuffle=False)


-class ArrayLikeTrainingLoop(training_utils.TrainingLoop):
+class ArrayLikeTrainingLoop(training_utils_v1.TrainingLoop):
 """TrainingLoop that handle inputs like array.

 This is the default handler for most of the input data types, includes
@@ -639,9 +639,9 @@ class ArrayLikeTrainingLoop(training_utils_v1.TrainingLoop):
 val_x, val_y, val_sample_weights = model._prepare_validation_data(
 validation_data, batch_size, validation_steps)
 elif validation_split and 0. < validation_split < 1.:
-(x, y, sample_weights, val_x, val_y,
-val_sample_weights) = training_utils.split_training_and_validation_data(
+(x, y, sample_weights, val_x, val_y, val_sample_weights
+) = training_utils_v1.split_training_and_validation_data(
 x, y, sample_weights, validation_split)
 else:
 if validation_steps:
 raise ValueError('`validation_steps` should not be specified if '
@@ -35,7 +35,7 @@ from tensorflow.python.keras.distribute import distributed_training_utils as dis
 from tensorflow.python.keras.distribute import distributed_training_utils_v1 as dist_utils
 from tensorflow.python.keras.engine import partial_batch_padding_handler as padding_util
 from tensorflow.python.keras.engine import training_arrays_v1
-from tensorflow.python.keras.engine import training_utils
+from tensorflow.python.keras.engine import training_utils_v1
 from tensorflow.python.keras.utils.generic_utils import Progbar
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
 from tensorflow.python.ops import array_ops
@@ -258,7 +258,7 @@ def experimental_tpu_fit_loop(model,
 break

 if (do_validation and
-training_utils.should_run_validation(validation_freq, epoch)):
+training_utils_v1.should_run_validation(validation_freq, epoch)):
 logging.info('Running validation at fit epoch: %s', epoch)

 if model._compile_distribution:
@@ -575,7 +575,7 @@ def experimental_tpu_predict_loop(model,
 return prediction_result


-class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
+class DistributionSingleWorkerTrainingLoop(training_utils_v1.TrainingLoop):
 """Training loop for distribution strategy with single worker."""

 def fit(self,
@@ -630,8 +630,8 @@ class DistributionSingleWorkerTrainingLoop(training_utils_v1.TrainingLoop):

 val_dataset = None
 if validation_data:
-val_x, val_y, val_sample_weights = training_utils.unpack_validation_data(
-validation_data)
+val_x, val_y, val_sample_weights = (
+training_utils_v1.unpack_validation_data(validation_data))
 dist_utils.validate_inputs(val_x, val_y)
 _, validation_steps = dist_utils.process_batch_and_step_size(
 model._distribution_strategy, val_x, batch_size, validation_steps,
@@ -650,7 +650,7 @@ class DistributionSingleWorkerTrainingLoop(training_utils_v1.TrainingLoop):
 'distribution strategies.')

 if dist_utils_v2.is_tpu_strategy(model._distribution_strategy):
-steps_per_epoch = training_utils.infer_steps_for_dataset(
+steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
 model, dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')
 if steps_per_epoch is None:
 raise ValueError('Number of steps could not be inferred from the data, '
@@ -707,7 +707,7 @@ class DistributionSingleWorkerTrainingLoop(training_utils_v1.TrainingLoop):
 allow_partial_batch=True)

 if dist_utils_v2.is_tpu_strategy(model._distribution_strategy):
-steps = training_utils.infer_steps_for_dataset(
+steps = training_utils_v1.infer_steps_for_dataset(
 model, dataset, steps, steps_name='steps')
 if steps is None:
 raise ValueError('Number of steps could not be inferred from the data, '
@@ -744,7 +744,7 @@ class DistributionSingleWorkerTrainingLoop(training_utils_v1.TrainingLoop):
 batch_size=batch_size,
 allow_partial_batch=True)
 if dist_utils_v2.is_tpu_strategy(model._distribution_strategy):
-steps = training_utils.infer_steps_for_dataset(
+steps = training_utils_v1.infer_steps_for_dataset(
 model, dataset, steps, steps_name='steps')
 if steps is None:
 raise ValueError('Number of steps could not be inferred from the data, '
@@ -780,7 +780,7 @@ def _train_with_multi_worker(method):
 return wrapper


-class DistributionMultiWorkerTrainingLoop(training_utils.TrainingLoop):
+class DistributionMultiWorkerTrainingLoop(training_utils_v1.TrainingLoop):
 """Training loop for distribution strategy with multiple worker."""

 def __init__(self, single_worker_loop):
@@ -25,7 +25,6 @@ from tensorflow.python.eager.backprop import GradientTape
 from tensorflow.python.framework import ops
 from tensorflow.python.keras import backend
 from tensorflow.python.keras.engine import training_utils
+from tensorflow.python.keras.engine import training_utils_v1
 from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
 from tensorflow.python.keras.utils import losses_utils
 from tensorflow.python.ops import math_ops
@@ -128,12 +127,11 @@ def _model_loss(model,
 outs = nest.flatten(outs)

 if targets:
-targets = training_utils.cast_if_floating_dtype_and_mismatch(targets, outs)
+targets = training_utils_v1.cast_if_floating_dtype_and_mismatch(
+targets, outs)
 # TODO(sallymatson/psv): check if we should do same mismatch fix for weights
 if sample_weights:
 sample_weights = [
-training_utils.cast_if_floating_dtype(
+training_utils_v1.cast_if_floating_dtype(
 ops.convert_to_tensor_v2_with_dispatch(val))
 if val is not None else None for val in sample_weights
 ]
@@ -306,7 +304,7 @@ def train_on_batch(model,
 model output. Could be a empty list when model has only one output.
 'metrics': list of tensors for metric specified.
 """
-inputs = training_utils.cast_to_model_input_dtypes(inputs, model)
+inputs = training_utils_v1.cast_to_model_input_dtypes(inputs, model)
 outs, total_loss, output_losses, masks = (
 _process_single_batch(
 model,
@@ -347,7 +345,7 @@ def test_on_batch(model,
 model output. Could be a empty list when model has only one output.
 'metrics': list of tensors for metric specified.
 """
-inputs = training_utils.cast_to_model_input_dtypes(inputs, model)
+inputs = training_utils_v1.cast_to_model_input_dtypes(inputs, model)

 with backend.eager_learning_phase_scope(0):
 outs, total_loss, output_losses, masks = (
@@ -31,7 +31,6 @@ from tensorflow.python.framework import errors
 from tensorflow.python.keras import backend
 from tensorflow.python.keras import callbacks as cbks
 from tensorflow.python.keras.engine import training_utils
+from tensorflow.python.keras.engine import training_utils_v1
 from tensorflow.python.keras.utils import data_utils
 from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
@@ -133,7 +132,7 @@ def model_iteration(model,
 original_dataset = data
 if steps_per_epoch is None:
 reset_dataset_after_each_epoch = True
-steps_per_epoch = training_utils.infer_steps_for_dataset(
+steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
 model, data, steps_per_epoch, epochs=epochs, steps_name=steps_name)

 # Convert to a format that supports `next(generator)`.
@@ -180,11 +179,9 @@ def model_iteration(model,
 mode=mode)

 if mode == ModeKeys.PREDICT:
-aggregator = training_utils.OutputsAggregator(True, steps=steps_per_epoch)
+aggregator = training_utils_v1.OutputsAggregator(
+True, steps=steps_per_epoch)
 else:
-aggregator = training_utils.MetricsAggregator(True, steps=steps_per_epoch)
+aggregator = training_utils_v1.MetricsAggregator(
+True, steps=steps_per_epoch)

 should_set_learning_phase = context.executing_eagerly() and model.run_eagerly
 if should_set_learning_phase:
@@ -296,7 +293,7 @@ def model_iteration(model,

 # Run the test loop every epoch during training.
 if (do_validation and
-training_utils.should_run_validation(validation_freq, epoch) and
+training_utils_v1.should_run_validation(validation_freq, epoch) and
 not callbacks.model.stop_training):
 val_results = model_iteration(
 model,
@@ -541,7 +538,7 @@ def _get_num_samples_or_steps(data, steps_per_epoch):
 return steps_per_epoch, True


-class GeneratorOrSequenceTrainingLoop(training_utils.TrainingLoop):
+class GeneratorOrSequenceTrainingLoop(training_utils_v1.TrainingLoop):
 """Generator-like.

 Input is Python generator, or Sequence object.
@@ -572,7 +569,7 @@ class GeneratorOrSequenceTrainingLoop(training_utils_v1.TrainingLoop):
 workers=1,
 use_multiprocessing=False):
 model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
-training_utils.check_generator_arguments(
+training_utils_v1.check_generator_arguments(
 y, sample_weight, validation_split=validation_split)
 return fit_generator(
 model,
@@ -605,7 +602,7 @@ class GeneratorOrSequenceTrainingLoop(training_utils_v1.TrainingLoop):
 workers=1,
 use_multiprocessing=False):
 model._validate_or_infer_batch_size(batch_size, steps, x)
-training_utils.check_generator_arguments(y, sample_weight)
+training_utils_v1.check_generator_arguments(y, sample_weight)
 return evaluate_generator(
 model,
 x,
@@ -638,7 +635,7 @@ class GeneratorOrSequenceTrainingLoop(training_utils_v1.TrainingLoop):
 use_multiprocessing=use_multiprocessing)


-class EagerDatasetOrIteratorTrainingLoop(training_utils.TrainingLoop):
+class EagerDatasetOrIteratorTrainingLoop(training_utils_v1.TrainingLoop):
 """A non-distributed Dataset or iterator in eager execution."""

 def fit(self,
@@ -661,11 +658,10 @@ class EagerDatasetOrIteratorTrainingLoop(training_utils_v1.TrainingLoop):
 **kwargs):
 model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
 # Make sure that y, sample_weights, validation_split are not passed.
-training_utils.validate_dataset_input(x, y, sample_weight, validation_split)
+training_utils_v1.validate_dataset_input(x, y, sample_weight,
+validation_split)
 if (isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)) and
 shuffle):
-training_utils.verify_dataset_shuffled(x)
+training_utils_v1.verify_dataset_shuffled(x)

 return fit_generator(
 model,
@@ -695,7 +691,7 @@ class EagerDatasetOrIteratorTrainingLoop(training_utils_v1.TrainingLoop):
 **kwargs):
 model._validate_or_infer_batch_size(batch_size, steps, x)
 # Make sure that y, sample_weights, validation_split are not passed.
-training_utils.validate_dataset_input(x, y, sample_weight)
+training_utils_v1.validate_dataset_input(x, y, sample_weight)
 return evaluate_generator(
 model, x, steps=steps, verbose=verbose, workers=0, callbacks=callbacks)

@@ -712,7 +708,7 @@ class EagerDatasetOrIteratorTrainingLoop(training_utils_v1.TrainingLoop):
 model, x, steps=steps, verbose=verbose, workers=0, callbacks=callbacks)


-class GeneratorLikeTrainingLoop(training_utils.TrainingLoop):
+class GeneratorLikeTrainingLoop(training_utils_v1.TrainingLoop):
 """TrainingLoop that handle inputs like python generator.

 This is the default handler for most of the input data types, includes
@@ -759,9 +755,8 @@ class GeneratorLikeTrainingLoop(training_utils_v1.TrainingLoop):
 validation_steps)
 elif validation_split and 0. < validation_split < 1.:
 (x, y, sample_weights, val_x, val_y,
-val_sample_weights) = training_utils.split_training_and_validation_data(
-x, y, sample_weights, validation_split)
+val_sample_weights) = (
+training_utils_v1.split_training_and_validation_data(
+x, y, sample_weights, validation_split))
 validation_data = (val_x, val_y, val_sample_weights)
 else:
 if validation_steps:
@@ -44,7 +44,7 @@ from tensorflow.python.keras.callbacks import Callback
 from tensorflow.python.keras.engine import input_layer
 from tensorflow.python.keras.engine import sequential
 from tensorflow.python.keras.engine import training as training_module
-from tensorflow.python.keras.engine import training_utils
+from tensorflow.python.keras.engine import training_utils_v1
 from tensorflow.python.keras.utils import data_utils
 from tensorflow.python.keras.utils import np_utils
 from tensorflow.python.ops import array_ops
@@ -2019,7 +2019,7 @@ class LossWeightingTest(keras_parameterized.TestCase):
 [[0, .4, 1, 1], [2, .4, .3, 1]])
 dataset = dataset_ops.Dataset.from_tensor_slices(sample_weights)
 sample_weights = dataset_ops.make_one_shot_iterator(dataset).get_next()
-sample_weights = training_utils.standardize_sample_weights(
+sample_weights = training_utils_v1.standardize_sample_weights(
 sample_weights, model.output_names)

 # Update model loss with sample weight tensor.
File diff suppressed because it is too large.
@@ -35,7 +35,7 @@ from tensorflow.python.keras import backend
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.engine import keras_tensor
-from tensorflow.python.keras.engine import training_utils
+from tensorflow.python.keras.engine import training_utils_v1
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging as logging
@@ -45,7 +45,7 @@ class ModelInputsTest(test.TestCase):

 def test_single_thing(self):
 a = np.ones(10)
-model_inputs = training_utils.ModelInputs(a)
+model_inputs = training_utils_v1.ModelInputs(a)
 self.assertEqual(['input_1'], model_inputs.get_input_names())
 vals = model_inputs.get_symbolic_inputs()
 self.assertTrue(tensor_util.is_tensor(vals))
@@ -59,7 +59,7 @@ class ModelInputsTest(test.TestCase):
 self.skipTest('Run in eager mode only.')
 with testing_utils.use_keras_tensors_scope(False):
 a = np.ones(10, dtype=np.int32)
-model_inputs = training_utils.ModelInputs(a)
+model_inputs = training_utils_v1.ModelInputs(a)
 self.assertEqual(['input_1'], model_inputs.get_input_names())
 val = model_inputs.get_symbolic_inputs()
 self.assertTrue(tf_utils.is_symbolic_tensor(val))
@@ -69,7 +69,7 @@ class ModelInputsTest(test.TestCase):
 self.assertEqual(dtypes.int32, vals[0].dtype)
 with testing_utils.use_keras_tensors_scope(True):
 a = np.ones(10, dtype=np.int32)
-model_inputs = training_utils.ModelInputs(a)
+model_inputs = training_utils_v1.ModelInputs(a)
 self.assertEqual(['input_1'], model_inputs.get_input_names())
 val = model_inputs.get_symbolic_inputs()
 self.assertIsInstance(val, keras_tensor.KerasTensor)
@@ -80,7 +80,7 @@ class ModelInputsTest(test.TestCase):

 def test_list(self):
 a = [np.ones(10), np.ones(20)]
-model_inputs = training_utils.ModelInputs(a)
+model_inputs = training_utils_v1.ModelInputs(a)
 self.assertEqual(['input_1', 'input_2'], model_inputs.get_input_names())
 vals = model_inputs.get_symbolic_inputs()
 self.assertTrue(tensor_util.is_tensor(vals[0]))
@@ -91,14 +91,14 @@ class ModelInputsTest(test.TestCase):
 self.skipTest('Run in eager mode only.')
 with testing_utils.use_keras_tensors_scope(False):
 a = [np.ones(10), np.ones(20)]
-model_inputs = training_utils.ModelInputs(a)
+model_inputs = training_utils_v1.ModelInputs(a)
 self.assertEqual(['input_1', 'input_2'], model_inputs.get_input_names())
 vals = model_inputs.get_symbolic_inputs()
 self.assertTrue(tf_utils.is_symbolic_tensor(vals[0]))
 self.assertTrue(tf_utils.is_symbolic_tensor(vals[1]))
 with testing_utils.use_keras_tensors_scope(True):
 a = [np.ones(10), np.ones(20)]
-model_inputs = training_utils.ModelInputs(a)
+model_inputs = training_utils_v1.ModelInputs(a)
 self.assertEqual(['input_1', 'input_2'], model_inputs.get_input_names())
 vals = model_inputs.get_symbolic_inputs()
 self.assertIsInstance(vals[0], keras_tensor.KerasTensor)
@@ -106,7 +106,7 @@ class ModelInputsTest(test.TestCase):

 def test_dict(self):
 a = {'b': np.ones(10), 'a': np.ones(20)}
-model_inputs = training_utils.ModelInputs(a)
+model_inputs = training_utils_v1.ModelInputs(a)
 self.assertEqual(['a', 'b'], model_inputs.get_input_names())
 vals = model_inputs.get_symbolic_inputs()
 self.assertTrue(tensor_util.is_tensor(vals['a']))
@@ -117,14 +117,14 @@ class ModelInputsTest(test.TestCase):
 self.skipTest('Run in eager mode only.')
 with testing_utils.use_keras_tensors_scope(False):
 a = {'b': np.ones(10), 'a': np.ones(20)}
-model_inputs = training_utils.ModelInputs(a)
+model_inputs = training_utils_v1.ModelInputs(a)
 self.assertEqual(['a', 'b'], model_inputs.get_input_names())
 vals = model_inputs.get_symbolic_inputs()
 self.assertTrue(tf_utils.is_symbolic_tensor(vals['a']))
 self.assertTrue(tf_utils.is_symbolic_tensor(vals['b']))
 with testing_utils.use_keras_tensors_scope(True):
 a = {'b': np.ones(10), 'a': np.ones(20)}
-model_inputs = training_utils.ModelInputs(a)
+model_inputs = training_utils_v1.ModelInputs(a)
 self.assertEqual(['a', 'b'], model_inputs.get_input_names())
 vals = model_inputs.get_symbolic_inputs()
 self.assertIsInstance(vals['a'], keras_tensor.KerasTensor)
@@ -182,12 +182,12 @@ class DatasetUtilsTest(test.TestCase, parameterized.TestCase):

 if not expect_shuffled:
 with test.mock.patch.object(logging, 'warning') as mock_log:
-shuffled = training_utils.verify_dataset_shuffled(dataset)
+shuffled = training_utils_v1.verify_dataset_shuffled(dataset)
 self.assertRegex(
 str(mock_log.call_args), 'input dataset `x` is not shuffled.')
 self.assertFalse(shuffled)
 else:
-self.assertTrue(training_utils.verify_dataset_shuffled(dataset))
+self.assertTrue(training_utils_v1.verify_dataset_shuffled(dataset))


 class StandardizeWeightsTest(keras_parameterized.TestCase):
@@ -195,22 +195,21 @@ class StandardizeWeightsTest(keras_parameterized.TestCase):
 def test_sample_weights(self):
 y = np.array([0, 1, 0, 0, 2])
 sample_weights = np.array([0.5, 1., 1., 0., 2.])
-weights = training_utils.standardize_weights(y, sample_weights)
+weights = training_utils_v1.standardize_weights(y, sample_weights)
 self.assertAllClose(weights, sample_weights)

 def test_class_weights(self):
 y = np.array([0, 1, 0, 0, 2])
 class_weights = {0: 0.5, 1: 1., 2: 1.5}
-weights = training_utils.standardize_weights(y, class_weight=class_weights)
+weights = training_utils_v1.standardize_weights(
+y, class_weight=class_weights)
 self.assertAllClose(weights, np.array([0.5, 1., 0.5, 0.5, 1.5]))

 def test_sample_weights_and_class_weights(self):
 y = np.array([0, 1, 0, 0, 2])
 sample_weights = np.array([0.5, 1., 1., 0., 2.])
 class_weights = {0: 0.5, 1: 1., 2: 1.5}
-weights = training_utils.standardize_weights(y, sample_weights,
+weights = training_utils_v1.standardize_weights(y, sample_weights,
 class_weights)
 expected = sample_weights * np.array([0.5, 1., 0.5, 0.5, 1.5])
 self.assertAllClose(weights, expected)

@@ -277,35 +276,32 @@ class AggregationTest(keras_parameterized.TestCase):

 def setUp(self):
 super(AggregationTest, self).setUp()
-self._old_pool = training_utils._COPY_POOL
-self._old_threshold = training_utils.SliceAggregator._BINARY_SIZE_THRESHOLD
-self._old_timeout = training_utils.SliceAggregator._MAX_COPY_SECONDS
-training_utils._COPY_POOL = MonitoredPool(training_utils._COPY_THREADS)
+self._old_pool = training_utils_v1._COPY_POOL
+self._old_threshold = (
+training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD)
+self._old_timeout = training_utils_v1.SliceAggregator._MAX_COPY_SECONDS
+training_utils_v1._COPY_POOL = MonitoredPool(
+training_utils_v1._COPY_THREADS)

 def tearDown(self):
 super(AggregationTest, self).tearDown()
-training_utils._COPY_POOL = self._old_pool
-training_utils.SliceAggregator._BINARY_SIZE_THRESHOLD = self._old_threshold
-training_utils.SliceAggregator._MAX_COPY_SECONDS = self._old_timeout
+training_utils_v1._COPY_POOL = self._old_pool
+training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD = (
+self._old_threshold)
+training_utils_v1.SliceAggregator._MAX_COPY_SECONDS = self._old_timeout

 def _run_with_steps(self):
-aggregator = training_utils.OutputsAggregator(use_steps=True)
+aggregator = training_utils_v1.OutputsAggregator(use_steps=True)
 for i, batch in enumerate(np.array_split(_TEST_DATA, 4)):
 if i == 0:
 aggregator.create(batch)
 aggregator.aggregate(batch)

 assert len(aggregator.results) == 1
-assert isinstance(aggregator.results[0], training_utils.ConcatAggregator)
+assert isinstance(aggregator.results[0], training_utils_v1.ConcatAggregator)

 aggregator.finalize()
 return aggregator.results

 def _run_without_steps(self):
-aggregator = training_utils.OutputsAggregator(
+aggregator = training_utils_v1.OutputsAggregator(
 use_steps=False, num_samples=6)

 batch_start = 0
@@ -318,7 +314,7 @@ class AggregationTest(keras_parameterized.TestCase):
 batch_start = batch_end

 assert len(aggregator.results) == 1
-assert isinstance(aggregator.results[0], training_utils.SliceAggregator)
+assert isinstance(aggregator.results[0], training_utils_v1.SliceAggregator)

 aggregator.finalize()
 return aggregator.results
@@ -330,7 +326,7 @@ class AggregationTest(keras_parameterized.TestCase):
 self.assertAllEqual(self._run_without_steps(), _TEST_DATA)

 def test_nested_aggregation(self):
-aggregator = training_utils.OutputsAggregator(
+aggregator = training_utils_v1.OutputsAggregator(
 use_steps=False, num_samples=6)

 batches = np.array_split(_TEST_DATA, 4)
@@ -348,46 +344,46 @@ class AggregationTest(keras_parameterized.TestCase):
 self.assertAllEqual(aggregator.results, (_TEST_DATA, _TEST_DATA))

 def test_concat_single_batch(self):
-aggregator = training_utils.OutputsAggregator(use_steps=True)
+aggregator = training_utils_v1.OutputsAggregator(use_steps=True)
 data = _TEST_DATA.copy()
 aggregator.create(data)
 assert len(aggregator.results) == 1
-assert isinstance(aggregator.results[0], training_utils.ConcatAggregator)
+assert isinstance(aggregator.results[0], training_utils_v1.ConcatAggregator)

 aggregator.aggregate(data)
 aggregator.finalize()
 assert aggregator.results is data # No copy.

 def test_slice_single_batch(self):
-aggregator = training_utils.OutputsAggregator(
+aggregator = training_utils_v1.OutputsAggregator(
 use_steps=False, num_samples=6)
 data = _TEST_DATA.copy()
 aggregator.create(data)
 assert len(aggregator.results) == 1
-assert isinstance(aggregator.results[0], training_utils.SliceAggregator)
+assert isinstance(aggregator.results[0], training_utils_v1.SliceAggregator)

 aggregator.aggregate(data, 0, 6)
 aggregator.finalize()
 assert aggregator.results is data # No copy.

 def test_async_copy(self):
-training_utils.SliceAggregator._BINARY_SIZE_THRESHOLD = 15
+training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD = 15
 self.assertAllEqual(self._run_without_steps(), _TEST_DATA)

 # Two of the four batches will have 20 elements and two will have 10.
-self.assertEqual(training_utils._COPY_POOL._apply_counter, 2)
+self.assertEqual(training_utils_v1._COPY_POOL._apply_counter, 2)

 def test_async_copy_timeout(self):
-training_utils.SliceAggregator._BINARY_SIZE_THRESHOLD = 15
-training_utils.SliceAggregator._MAX_COPY_SECONDS = 0.1
-training_utils._COPY_POOL._func_wrapper = add_sleep
+training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD = 15
+training_utils_v1.SliceAggregator._MAX_COPY_SECONDS = 0.1
+training_utils_v1._COPY_POOL._func_wrapper = add_sleep
 with self.assertRaisesRegex(ValueError, 'Timed out waiting for copy'):
 self._run_without_steps()

 def test_async_copy_reraise(self):
-training_utils.SliceAggregator._BINARY_SIZE_THRESHOLD = 15
-training_utils.SliceAggregator._MAX_COPY_SECONDS = 1.
-training_utils._COPY_POOL._func_wrapper = cause_error
+training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD = 15
+training_utils_v1.SliceAggregator._MAX_COPY_SECONDS = 1.
+training_utils_v1._COPY_POOL._func_wrapper = cause_error
 with self.assertRaisesRegex(TypeError, 'NoneType'):
 self._run_without_steps()
File diff suppressed because it is too large.
@ -51,7 +51,6 @@ from tensorflow.python.keras.engine import training_distributed_v1
|
|||||||
from tensorflow.python.keras.engine import training_eager_v1
|
from tensorflow.python.keras.engine import training_eager_v1
|
||||||
from tensorflow.python.keras.engine import training_generator_v1
|
from tensorflow.python.keras.engine import training_generator_v1
|
||||||
from tensorflow.python.keras.engine import training_utils
|
from tensorflow.python.keras.engine import training_utils
|
||||||
from tensorflow.python.keras.engine import training_utils_v1
|
|
||||||
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
|
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
|
||||||
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||||
from tensorflow.python.keras.saving.saved_model import model_serialization
|
from tensorflow.python.keras.saving.saved_model import model_serialization
|
||||||
@ -414,7 +413,7 @@ class Model(training_lib.Model):
|
|||||||
base_layer.keras_api_gauge.get_cell('compile').set(True)
|
base_layer.keras_api_gauge.get_cell('compile').set(True)
|
||||||
|
|
||||||
# Prepare list of loss functions, same size of model outputs.
|
# Prepare list of loss functions, same size of model outputs.
|
||||||
self.loss_functions = training_utils_v1.prepare_loss_functions(
|
self.loss_functions = training_utils.prepare_loss_functions(
|
||||||
self.loss, self.output_names)
|
self.loss, self.output_names)
|
||||||
|
|
||||||
target_tensors = self._process_target_tensor_for_compile(target_tensors)
|
target_tensors = self._process_target_tensor_for_compile(target_tensors)
|
||||||
@ -426,8 +425,7 @@ class Model(training_lib.Model):
|
|||||||
self._training_endpoints.append(endpoint)
|
self._training_endpoints.append(endpoint)

# Prepare list loss weights, same size of model outputs.
training_utils_v1.prepare_loss_weights(self._training_endpoints,
training_utils.prepare_loss_weights(self._training_endpoints, loss_weights)
loss_weights)

# Initialization for Eager mode execution.
if self.run_eagerly:
@ -449,7 +447,7 @@ class Model(training_lib.Model):
masks=self._prepare_output_masks())

# Prepare sample weight modes. List with the same length as model outputs.
training_utils_v1.prepare_sample_weight_modes(
training_utils.prepare_sample_weight_modes(
self._training_endpoints, sample_weight_mode)

# Creates the model loss and weighted metrics sub-graphs.
@ -595,7 +593,7 @@ class Model(training_lib.Model):
# or a non-distributed Dataset or iterator in eager execution.
if data_utils.is_generator_or_sequence(inputs):
return training_generator_v1.GeneratorOrSequenceTrainingLoop()
if training_utils_v1.is_eager_dataset_or_iterator(inputs):
if training_utils.is_eager_dataset_or_iterator(inputs):
return training_generator_v1.EagerDatasetOrIteratorTrainingLoop()

# Case 3: Symbolic tensors or Numpy array-like.
@ -1076,7 +1074,7 @@ class Model(training_lib.Model):
+ output_dict['metrics'])
outputs = [_non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access
else:
x = training_utils_v1.ModelInputs(x).as_list()
x = training_utils.ModelInputs(x).as_list()
ins = x + list(y or []) + list(sample_weights or [])

if not isinstance(K.symbolic_learning_phase(), int):
@ -1155,7 +1153,7 @@ class Model(training_lib.Model):
+ output_dict['metrics'])
outputs = [_non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access
else:
x = training_utils_v1.ModelInputs(x).as_list()
x = training_utils.ModelInputs(x).as_list()
inputs = x + list(y or []) + list(sample_weights or [])

self._update_sample_weight_modes(sample_weights=sample_weights)
@ -1200,7 +1198,7 @@ class Model(training_lib.Model):
# If `self._distribution_strategy` is True, then we are in a replica context
# at this point.
if self.run_eagerly or self._distribution_strategy:
inputs = training_utils_v1.cast_if_floating_dtype(inputs)
inputs = training_utils.cast_if_floating_dtype(inputs)
if isinstance(inputs, collections_abc.Sequence):
# Unwrap lists with only one input, as we do when training on batch
if len(inputs) == 1:
@ -1372,7 +1370,7 @@ class Model(training_lib.Model):
def _prepare_validation_data(self, validation_data, batch_size,
validation_steps):
"""Unpack and check the validation data."""
val_x, val_y, val_sample_weights = training_utils_v1.unpack_validation_data(
val_x, val_y, val_sample_weights = training_utils.unpack_validation_data(
validation_data)
return self._standardize_user_data(
val_x,
@ -1451,7 +1449,7 @@ class Model(training_lib.Model):

def _compile_eagerly(self, metrics, weighted_metrics, sample_weight_mode):
# Prepare sample weight modes. List with the same length as model outputs.
training_utils_v1.prepare_sample_weight_modes(
training_utils.prepare_sample_weight_modes(
self._training_endpoints, sample_weight_mode)
# Prepare sample weights.
self._prepare_sample_weights()
@ -1790,10 +1788,10 @@ class Model(training_lib.Model):
output_shapes.append(None)
else:
output_shapes.append(output.shape.as_list())
self._per_output_metrics = training_utils_v1.collect_per_output_metric_info(
self._per_output_metrics = training_utils.collect_per_output_metric_info(
metrics, self.output_names, output_shapes, self.loss_functions)
self._per_output_weighted_metrics = (
training_utils_v1.collect_per_output_metric_info(
training_utils.collect_per_output_metric_info(
weighted_metrics,
self.output_names,
output_shapes,
@ -1903,7 +1901,7 @@ class Model(training_lib.Model):
metric_results = []
for metric_name, metric_fn in metrics_dict.items():
with K.name_scope(metric_name):
metric_result = training_utils_v1.call_metric_function(
metric_result = training_utils.call_metric_function(
metric_fn, y_true, y_pred, weights=weights, mask=mask)
metric_results.append(metric_result)
return metric_results
@ -2140,7 +2138,7 @@ class Model(training_lib.Model):
# in the codebase.
if isinstance(x, dataset_ops.DatasetV2):
if shuffle:
training_utils_v1.verify_dataset_shuffled(x)
training_utils.verify_dataset_shuffled(x)

strategy = self._distribution_strategy
with strategy.scope():
@ -2192,8 +2190,8 @@ class Model(training_lib.Model):
x = ds.batch(batch_size, drop_remainder=drop_remainder)
else:
assert isinstance(x, dataset_ops.DatasetV2)
training_utils_v1.validate_dataset_input(x, y, sample_weight,
training_utils.validate_dataset_input(x, y, sample_weight,
validation_split)
return x

def _standardize_user_data(self,
@ -2270,28 +2268,28 @@ class Model(training_lib.Model):
# Graph mode dataset. We'll pass the dataset as-is (unless
# `extract_tensors_from_dataset` is True, in which case we extract
# the tensors from the dataset and we output them.
training_utils_v1.validate_dataset_input(x, y, sample_weight,
training_utils.validate_dataset_input(x, y, sample_weight,
validation_split)
if shuffle:
training_utils_v1.verify_dataset_shuffled(x)
training_utils.verify_dataset_shuffled(x)

is_dataset = True
if extract_tensors_from_dataset:
# We do this for `train_on_batch`/etc.
x, y, sample_weight = training_utils_v1.extract_tensors_from_dataset(x)
x, y, sample_weight = training_utils.extract_tensors_from_dataset(x)
elif isinstance(x, iterator_ops.Iterator):
# Graph mode iterator. We extract the symbolic tensors.
training_utils_v1.validate_dataset_input(x, y, sample_weight,
training_utils.validate_dataset_input(x, y, sample_weight,
validation_split)
iterator = x
x, y, sample_weight = training_utils_v1.unpack_iterator_input(iterator)
x, y, sample_weight = training_utils.unpack_iterator_input(iterator)
is_dataset = True
else:
is_dataset = False

# Validates `steps` argument based on x's type.
if check_steps:
training_utils_v1.check_steps_argument(x, steps, steps_name)
training_utils.check_steps_argument(x, steps, steps_name)

# First, we build the model on the fly if necessary.
if not self.inputs:
@ -2354,7 +2352,7 @@ class Model(training_lib.Model):
# Standardize the inputs.
if not isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
# TODO(fchollet): run static checks with dataset output shape(s).
x = training_utils_v1.standardize_input_data(
x = training_utils.standardize_input_data(
x,
feed_input_names,
feed_input_shapes,
@ -2401,8 +2399,8 @@ class Model(training_lib.Model):
if y is not None:
# Prepare self._sample_weight_modes. List with the same length as
# model outputs.
training_utils_v1.prepare_sample_weight_modes(self._training_endpoints,
training_utils.prepare_sample_weight_modes(self._training_endpoints,
self.sample_weight_mode)
feed_output_names = self._feed_output_names
feed_sample_weight_modes = self._sample_weight_modes
if not self._is_graph_network:
@ -2411,7 +2409,7 @@ class Model(training_lib.Model):
feed_output_shapes = self._feed_output_shapes

# Standardize the outputs.
y = training_utils_v1.standardize_input_data(
y = training_utils.standardize_input_data(
y,
feed_output_names,
# Don't enforce target shapes to match output shapes.
@ -2422,22 +2420,22 @@ class Model(training_lib.Model):

# Generate sample-wise weight values given the `sample_weight` and
# `class_weight` arguments.
sample_weights = training_utils_v1.standardize_sample_weights(
sample_weights = training_utils.standardize_sample_weights(
sample_weight, feed_output_names)
class_weights = training_utils_v1.standardize_class_weights(
class_weights = training_utils.standardize_class_weights(
class_weight, feed_output_names)

sample_weights = [
training_utils_v1.standardize_weights(ref, sw, cw, mode)
training_utils.standardize_weights(ref, sw, cw, mode)
for (ref, sw, cw, mode) in zip(y, sample_weights, class_weights,
feed_sample_weight_modes)
]
# Check that all arrays have the same length.
if not self._distribution_strategy:
training_utils_v1.check_array_lengths(x, y, sample_weights)
training_utils.check_array_lengths(x, y, sample_weights)
if self._is_graph_network and not run_eagerly:
# Additional checks to avoid users mistakenly using improper loss fns.
training_utils_v1.check_loss_and_target_compatibility(
training_utils.check_loss_and_target_compatibility(
y, self._feed_loss_fns, feed_output_shapes)

sample_weights, _, _ = training_utils.handle_partial_sample_weights(
@ -2472,12 +2470,11 @@ class Model(training_lib.Model):
# iterator and only one batch of samples is required, we fetch the data
# tensors from the iterator and then standardize them.
if isinstance(inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
inputs, targets, _ = training_utils_v1.extract_tensors_from_dataset(
inputs, targets, _ = training_utils.extract_tensors_from_dataset(inputs)
inputs)
# We type-check that `inputs` and `targets` are either single arrays
# or lists of arrays, and extract a flat list of inputs from the passed
# structure.
training_utils_v1.validate_input_types(inputs, orig_inputs)
training_utils.validate_input_types(inputs, orig_inputs)

if isinstance(inputs, (list, tuple)):
processed_inputs += list(inputs)
@ -2512,14 +2509,14 @@ class Model(training_lib.Model):
if not self.inputs:
# For subclassed models, a robust input spec is not available so we
# must cast to the model dtype.
inputs = training_utils_v1.cast_if_floating_dtype(inputs, self.dtype)
inputs = training_utils.cast_if_floating_dtype(inputs, self.dtype)

def create_tensor_spec(t):
return tensor_spec.TensorSpec(t.shape, t.dtype)

cast_inputs = nest.map_structure(create_tensor_spec, inputs)
elif training_utils_v1.has_tensors(inputs):
elif training_utils.has_tensors(inputs):
cast_inputs = training_utils_v1.cast_if_floating_dtype(inputs)
cast_inputs = training_utils.cast_if_floating_dtype(inputs)
else:
cast_inputs = inputs
self._set_inputs(cast_inputs)
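The `create_tensor_spec` helper in the hunk above records only the shape and dtype of each input and relies on `nest.map_structure` to preserve arbitrary nesting. A minimal standalone sketch of that pattern using the public `tf.nest` and `tf.TensorSpec` APIs (the example tensors and dictionary keys are illustrative assumptions, not part of this change):

import tensorflow as tf

def create_tensor_spec(t):
  # Keep only the static shape and dtype of each leaf tensor.
  return tf.TensorSpec(t.shape, t.dtype)

# Hypothetical nested inputs; map_structure applies the function to every
# leaf and keeps the dict structure intact.
inputs = {'image': tf.zeros([8, 32, 32, 3]), 'label': tf.zeros([8], dtype=tf.int64)}
cast_inputs = tf.nest.map_structure(create_tensor_spec, inputs)
# cast_inputs now holds TensorSpec objects with the same shapes and dtypes.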
@ -2528,11 +2525,11 @@ class Model(training_lib.Model):
def _compile_from_inputs(self, all_inputs, target, orig_inputs, orig_target):
if target is not None:
# We need to use `y` to set the model targets.
if training_utils_v1.has_tensors(target):
if training_utils.has_tensors(target):
target = training_utils_v1.cast_if_floating_dtype_and_mismatch(
target = training_utils.cast_if_floating_dtype_and_mismatch(
target, self.outputs)
training_utils_v1.validate_input_types(
training_utils.validate_input_types(target, orig_target,
target, orig_target, allow_dict=False, field_name='target')
allow_dict=False, field_name='target')
if isinstance(target, (list, tuple)):
all_inputs += list(target)
else:
@ -2631,7 +2628,7 @@ class Model(training_lib.Model):
input_shape = (None,) + tuple(inputs.as_list()[1:])
elif isinstance(inputs, dict):
# We assert that the first layer is a FeatureLayer.
if not training_utils_v1.is_feature_layer(self.layers[0]):
if not training_utils.is_feature_layer(self.layers[0]):
raise ValueError('Passing a dictionary input to a Sequential Model '
'which doesn\'t have FeatureLayer as the first layer'
' is an error.')
@ -2646,7 +2643,7 @@ class Model(training_lib.Model):

# On-the-fly setting of symbolic model inputs (either by using the tensor
# provided, or by creating a placeholder if Numpy data was provided).
model_inputs = training_utils_v1.ModelInputs(inputs)
model_inputs = training_utils.ModelInputs(inputs)
inputs = model_inputs.get_symbolic_inputs()
self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True)
self.input_names = model_inputs.get_input_names()
@ -2670,7 +2667,7 @@ class Model(training_lib.Model):
# data adapter since it assumes nest.flatten ordering.
outputs = nest.flatten(outputs)
self.outputs = outputs
self.output_names = training_utils_v1.generic_output_names(outputs)
self.output_names = training_utils.generic_output_names(outputs)
# TODO(scottzhu): Should we cleanup the self._training_endpoints here?
self.built = True