Fix the partial-batch issue for numpy data in training_v2.
1. Update the data adapter to expose the final partial-batch information when it is known.
2. Update training_v2 to aggregate based on the number of examples rather than the number of steps when there is a known partial batch. The callbacks/progress bar will also use this in a follow-up CL.

PiperOrigin-RevId: 259782295
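For intuition, with the numbers used in the tests below (50 samples, batch_size=4, so 13 steps whose final batch holds 2 examples), a per-step average over-weights the short last batch, while a per-example average does not. A minimal sketch, with made-up loss values:

batch_sizes = [4] * 12 + [2]       # 50 examples split into 13 batches
batch_losses = [1.0] * 12 + [2.0]  # pretend the short final batch is worse

step_mean = sum(batch_losses) / len(batch_losses)  # 14/13 ~= 1.077
example_mean = (sum(l * n for l, n in zip(batch_losses, batch_sizes))
                / sum(batch_sizes))                # 52/50 = 1.04

# Step averaging gives the 2-example batch the same weight as a 4-example
# batch; this change moves training_v2 toward the example-weighted mean.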
@@ -152,6 +152,14 @@ class DataAdapter(object):
     """Whether the dataset has partial batch at the end."""
     raise NotImplementedError

+  @abc.abstractmethod
+  def partial_batch_size(self):
+    """The size of the final partial batch for dataset.
+
+    Will return None if has_partial_batch is False or batch_size is None.
+    """
+    raise NotImplementedError
+

 class TensorLikeDataAdapter(DataAdapter):
   """Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy."""
@@ -196,6 +204,11 @@ class TensorLikeDataAdapter(DataAdapter):
       self._size = 1
       self._batch_size = num_samples
       self._has_partial_batch = False
+
+    self._partial_batch_size = None
+    if self._has_partial_batch:
+      self._partial_batch_size = (
+          num_samples - (self._size - 1) * self._batch_size)

     self._dataset = dataset

   def get_dataset(self):
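As a standalone sketch of the bookkeeping above (the helper name is ours, and computing the flag from the ceiling division is an assumption, since the diff only shows the else branch; the partial-size expression mirrors the adapter code):

import math

def partial_batch_info(num_samples, batch_size):
  # Number of batches via ceiling division, then the leftover examples in
  # the last batch, if any.
  size = int(math.ceil(num_samples / float(batch_size)))
  has_partial = (size * batch_size != num_samples)
  partial = num_samples - (size - 1) * batch_size if has_partial else None
  return size, partial

print(partial_batch_info(50, 4))  # (13, 2): twelve full batches plus 2 leftover
print(partial_batch_info(48, 4))  # (12, None): divides evenly, no partial batch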
@@ -210,6 +223,9 @@ class TensorLikeDataAdapter(DataAdapter):
   def has_partial_batch(self):
     return self._has_partial_batch

+  def partial_batch_size(self):
+    return self._partial_batch_size
+

 class DatasetAdapter(DataAdapter):
   """Adapter that handles `tf.data.Dataset`."""
@@ -243,6 +259,9 @@ class DatasetAdapter(DataAdapter):
   def has_partial_batch(self):
     return False

+  def partial_batch_size(self):
+    return None
+

 class GeneratorDataAdapter(DataAdapter):
   """Adapter that handles python generator."""
@@ -288,6 +307,9 @@ class GeneratorDataAdapter(DataAdapter):
   def has_partial_batch(self):
     return False

+  def partial_batch_size(self):
+    return None
+

 class KerasSequenceAdapter(DataAdapter):
   """Adapter that handles `keras.utils.Sequence`."""
@@ -331,6 +353,9 @@ class KerasSequenceAdapter(DataAdapter):
   def has_partial_batch(self):
     return False

+  def partial_batch_size(self):
+    return None
+

 ALL_ADAPTER_CLS = [
     TensorLikeDataAdapter, DatasetAdapter, GeneratorDataAdapter,
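Mirroring the tests that follow, querying the new accessor might look like this; the import path is an assumption, and the constructor signature is taken from the test calls below:

import numpy as np
from tensorflow.python.keras.engine import data_adapter  # assumed path

x = np.random.random((50, 10))
y = np.random.random((50, 1))
adapter = data_adapter.TensorLikeDataAdapter(x, y, batch_size=4)

print(adapter.get_size())            # 13: ceil(50 / 4) batches
print(adapter.has_partial_batch())   # True: 50 is not a multiple of 4
print(adapter.partial_batch_size())  # 2: examples left for the 13th batch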
@@ -102,6 +102,7 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase):
         self.numpy_input, self.numpy_target, batch_size=4)
     self.assertEqual(adapter.get_size(), 13)   # 50/4
     self.assertTrue(adapter.has_partial_batch())
+    self.assertEqual(adapter.partial_batch_size(), 2)

   def test_training_numpy(self):
     dataset = self.adapter_cls(
@@ -140,6 +141,7 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase):
         self.tensor_input, self.tensor_target, batch_size=4)
     self.assertEqual(adapter.get_size(), 13)   # 50/4
     self.assertTrue(adapter.has_partial_batch())
+    self.assertEqual(adapter.partial_batch_size(), 2)


 class DatasetAdapterTest(DataAdapterTestBase):
@@ -171,6 +173,7 @@ class DatasetAdapterTest(DataAdapterTestBase):
   def test_partial_batch(self):
     adapter = self.adapter_cls(self.dataset_input)
     self.assertFalse(adapter.has_partial_batch())
+    self.assertIsNone(adapter.partial_batch_size())


 class GeneratorDataAdapterTest(DataAdapterTestBase):
@@ -202,6 +205,7 @@ class GeneratorDataAdapterTest(DataAdapterTestBase):
   def test_partial_batch(self):
     adapter = self.adapter_cls(self.generator_input)
     self.assertFalse(adapter.has_partial_batch())
+    self.assertIsNone(adapter.partial_batch_size())


 class KerasSequenceAdapterTest(DataAdapterTestBase):
@@ -233,6 +237,7 @@ class KerasSequenceAdapterTest(DataAdapterTestBase):
   def test_partial_batch(self):
     adapter = self.adapter_cls(self.sequence_input)
     self.assertFalse(adapter.has_partial_batch())
+    self.assertIsNone(adapter.partial_batch_size())


 if __name__ == '__main__':
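The expected values above follow directly from the adapter bookkeeping: with 50 samples and batch_size=4, get_size() is ceil(50/4) = 13, and the final batch holds 50 - 12 * 4 = 2 examples. The other adapters (dataset, generator, Sequence) simply report has_partial_batch() as False and partial_batch_size() as None, as their hunks above show.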
@@ -61,7 +61,8 @@ def run_one_epoch(model,
                   steps_per_epoch=None,
                   mode=ModeKeys.TRAIN,
                   training_context=None,
-                  total_epochs=None):
+                  total_epochs=None,
+                  partical_batch_size=None):
   """Run the execution function with the data from iterator.

   Given the dataset iterator and execution function, get the data from iterator
@@ -81,15 +82,26 @@ def run_one_epoch(model,
     total_epochs: the total number of epochs that will be run.
       Used to raise an error when the iterator unexpectedly
       reaches its end.
+    partical_batch_size: the size of the final batch if it is already known. It
+      will be used to scale the loss value for the final batch.

   Returns:
     The loss and metric value from the model.
   """
+  # Only use the sample count if there is a partial batch at the end.
+  use_steps = not (partical_batch_size and batch_size and steps_per_epoch and
+                   steps_per_epoch == dataset_size)
+  num_samples = None if use_steps else batch_size * (steps_per_epoch -
+                                                     1) + partical_batch_size
+
   if mode == ModeKeys.PREDICT:
     aggregator = training_utils.OutputsAggregator(
-        use_steps=True, steps=steps_per_epoch, batch_size=batch_size)
+        use_steps=use_steps,
+        steps=steps_per_epoch,
+        num_samples=num_samples,
+        batch_size=batch_size)
   else:
     aggregator = training_utils.MetricsAggregator(
-        use_steps=True, steps=steps_per_epoch)
+        use_steps=use_steps, steps=steps_per_epoch, num_samples=num_samples)
   callbacks = training_context.callbacks
   progbar = training_context.progbar
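Extracted as a standalone sketch (the wrapper function is ours; the two expressions are the ones added above). Note that num_samples counts twelve full batches of four plus the two-example remainder:

def aggregation_plan(partical_batch_size, batch_size, steps_per_epoch,
                     dataset_size):
  # Fall back to step counting unless the final batch size is known and the
  # epoch walks the entire dataset.
  use_steps = not (partical_batch_size and batch_size and steps_per_epoch and
                   steps_per_epoch == dataset_size)
  num_samples = (None if use_steps
                 else batch_size * (steps_per_epoch - 1) + partical_batch_size)
  return use_steps, num_samples

print(aggregation_plan(2, 4, 13, 13))     # (False, 50): aggregate by example
print(aggregation_plan(None, 4, 13, 13))  # (True, None): aggregate by step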
@@ -143,7 +155,14 @@ def run_one_epoch(model,

       if step == 0:
         aggregator.create(batch_outs)
-      aggregator.aggregate(batch_outs)
+
+      if use_steps:
+        aggregator.aggregate(batch_outs)
+      else:
+        aggregator.aggregate(
+            batch_outs,
+            batch_start=step * batch_size,
+            batch_end=min((step + 1) * batch_size, num_samples))
       cbks.make_logs(model, batch_logs, batch_outs, mode)

       training_context.callbacks._call_batch_hook(
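Why the example ranges matter: a toy average over the 13 batches. This stands in for the aggregator internals, which are not part of this diff, so the weighting shown is an assumption about what example-based aggregation does:

num_samples, batch_size, steps = 50, 4, 13
total = 0.0
for step in range(steps):
  batch_start = step * batch_size
  batch_end = min((step + 1) * batch_size, num_samples)  # 50, not 52, at step 12
  batch_loss = 1.0  # stand-in for the loss in batch_outs
  total += batch_loss * (batch_end - batch_start)
print(total / num_samples)  # 1.0: the final batch contributes weight 2, not 4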
@@ -286,7 +305,8 @@ class Loop(training_utils.TrainingLoop):
                 steps_per_epoch=steps_per_epoch,
                 mode=ModeKeys.TRAIN,
                 training_context=training_context,
-                total_epochs=epochs)
+                total_epochs=epochs,
+                partical_batch_size=training_data_adapter.partial_batch_size())
         cbks.make_logs(model, epoch_logs, training_result, ModeKeys.TRAIN)

         # Evaluation
@@ -316,7 +336,9 @@ class Loop(training_utils.TrainingLoop):
                     steps_per_epoch=validation_steps,
                     mode=ModeKeys.TEST,
                     training_context=eval_context,
-                    total_epochs=1)
+                    total_epochs=1,
+                    partical_batch_size=validation_adapter.partial_batch_size(
+                    ))
             cbks.make_logs(model, epoch_logs, eval_result, ModeKeys.TEST,
                            prefix='val_')

@@ -389,7 +411,8 @@ class Loop(training_utils.TrainingLoop):
             steps_per_epoch=steps,
             mode=mode,
             training_context=training_context,
-            total_epochs=1)
+            total_epochs=1,
+            partical_batch_size=adapter.partial_batch_size())
     cbks.make_logs(model, epoch_logs, result, mode)

     if len(result) == 1: