Fix partial batch issue for the numpy data in training_v2.

1. Update the data adapter to expose the final partial batch size when it is known.
2. Update training_v2 to aggregate based on the number of examples rather than the number of steps when there is a known partial batch. The callback/progress bar will also use this in a follow-up CL.
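For reference, a minimal standalone sketch of the partial-batch arithmetic this change relies on (plain Python with illustrative variable names, not the actual adapter code):

    import math

    num_samples = 50
    batch_size = 4

    size = math.ceil(num_samples / batch_size)        # 13 steps per epoch
    has_partial_batch = num_samples % batch_size > 0  # True: 50 % 4 == 2
    partial_batch_size = (
        num_samples - (size - 1) * batch_size         # 50 - 12 * 4 == 2
        if has_partial_batch else None)

    # With a known partial batch, aggregation can switch from step counting to
    # example counting: 12 full batches of 4 plus a final batch of 2.
    num_examples = batch_size * (size - 1) + partial_batch_size  # 50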

PiperOrigin-RevId: 259782295
Scott Zhu 2019-07-24 11:34:50 -07:00 committed by TensorFlower Gardener
parent 5e5b01c914
commit 7251a1efe4
3 changed files with 60 additions and 7 deletions

View File

@@ -152,6 +152,14 @@ class DataAdapter(object):
    """Whether the dataset has partial batch at the end."""
    raise NotImplementedError

  @abc.abstractmethod
  def partial_batch_size(self):
    """The size of the final partial batch for dataset.

    Will return None if has_partial_batch is False or batch_size is None.
    """
    raise NotImplementedError


class TensorLikeDataAdapter(DataAdapter):
  """Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy."""
@@ -196,6 +204,11 @@ class TensorLikeDataAdapter(DataAdapter):
      self._size = 1
      self._batch_size = num_samples
      self._has_partial_batch = False

    self._partial_batch_size = None
    if self._has_partial_batch:
      self._partial_batch_size = (
          num_samples - (self._size - 1) * self._batch_size)

    self._dataset = dataset

  def get_dataset(self):
@@ -210,6 +223,9 @@ class TensorLikeDataAdapter(DataAdapter):
  def has_partial_batch(self):
    return self._has_partial_batch

  def partial_batch_size(self):
    return self._partial_batch_size


class DatasetAdapter(DataAdapter):
  """Adapter that handles `tf.data.Dataset`."""
@@ -243,6 +259,9 @@ class DatasetAdapter(DataAdapter):
  def has_partial_batch(self):
    return False

  def partial_batch_size(self):
    return None


class GeneratorDataAdapter(DataAdapter):
  """Adapter that handles python generator."""
@@ -288,6 +307,9 @@ class GeneratorDataAdapter(DataAdapter):
  def has_partial_batch(self):
    return False

  def partial_batch_size(self):
    return None


class KerasSequenceAdapter(DataAdapter):
  """Adapter that handles `keras.utils.Sequence`."""
@@ -331,6 +353,9 @@ class KerasSequenceAdapter(DataAdapter):
  def has_partial_batch(self):
    return False

  def partial_batch_size(self):
    return None


ALL_ADAPTER_CLS = [
    TensorLikeDataAdapter, DatasetAdapter, GeneratorDataAdapter,
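Taken together, the contract the adapters above now satisfy can be sketched as follows (illustrative only; SimpleArrayAdapter is a made-up name, not the TensorFlow implementation):

    import math

    class SimpleArrayAdapter(object):
      """Toy adapter showing the has_partial_batch/partial_batch_size contract."""

      def __init__(self, num_samples, batch_size):
        self._size = int(math.ceil(num_samples / float(batch_size)))
        self._batch_size = batch_size
        self._has_partial_batch = num_samples % batch_size > 0
        self._partial_batch_size = None
        if self._has_partial_batch:
          self._partial_batch_size = (
              num_samples - (self._size - 1) * self._batch_size)

      def has_partial_batch(self):
        return self._has_partial_batch

      def partial_batch_size(self):
        # None when the batch size divides the data evenly, matching the
        # dataset/generator/Sequence adapters above, which always return None.
        return self._partial_batch_size

    adapter = SimpleArrayAdapter(num_samples=50, batch_size=4)
    assert adapter.has_partial_batch()
    assert adapter.partial_batch_size() == 2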

View File

@@ -102,6 +102,7 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase):
        self.numpy_input, self.numpy_target, batch_size=4)
    self.assertEqual(adapter.get_size(), 13)  # 50/4
    self.assertTrue(adapter.has_partial_batch())
    self.assertEqual(adapter.partial_batch_size(), 2)

  def test_training_numpy(self):
    dataset = self.adapter_cls(
@@ -140,6 +141,7 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase):
        self.tensor_input, self.tensor_target, batch_size=4)
    self.assertEqual(adapter.get_size(), 13)  # 50/4
    self.assertTrue(adapter.has_partial_batch())
    self.assertEqual(adapter.partial_batch_size(), 2)


class DatasetAdapterTest(DataAdapterTestBase):
@@ -171,6 +173,7 @@ class DatasetAdapterTest(DataAdapterTestBase):
  def test_partial_batch(self):
    adapter = self.adapter_cls(self.dataset_input)
    self.assertFalse(adapter.has_partial_batch())
    self.assertIsNone(adapter.partial_batch_size())


class GeneratorDataAdapterTest(DataAdapterTestBase):
@@ -202,6 +205,7 @@ class GeneratorDataAdapterTest(DataAdapterTestBase):
  def test_partial_batch(self):
    adapter = self.adapter_cls(self.generator_input)
    self.assertFalse(adapter.has_partial_batch())
    self.assertIsNone(adapter.partial_batch_size())


class KerasSequenceAdapterTest(DataAdapterTestBase):
@@ -233,6 +237,7 @@ class KerasSequenceAdapterTest(DataAdapterTestBase):
  def test_partial_batch(self):
    adapter = self.adapter_cls(self.sequence_input)
    self.assertFalse(adapter.has_partial_batch())
    self.assertIsNone(adapter.partial_batch_size())


if __name__ == '__main__':

View File

@@ -61,7 +61,8 @@ def run_one_epoch(model,
                  steps_per_epoch=None,
                  mode=ModeKeys.TRAIN,
                  training_context=None,
                  total_epochs=None):
                  total_epochs=None,
                  partial_batch_size=None):
  """Run the execution function with the data from iterator.

  Given the dataset iterator and execution function, get the data from iterator
@@ -81,15 +82,26 @@ def run_one_epoch(model,
    total_epochs: the total number of epochs that will be run. Used to raise
      an error when the iterator unexpectedly reaches its end.
    partial_batch_size: the size of the final batch if it is already known. It
      will be used to scale the loss value for the final batch.

  Returns:
    The loss and metric value from the model.
  """
  # Only aggregate by the number of samples when there is a known partial
  # batch at the end.
  use_steps = not (partial_batch_size and batch_size and steps_per_epoch and
                   steps_per_epoch == dataset_size)
  num_samples = (None if use_steps else
                 batch_size * (steps_per_epoch - 1) + partial_batch_size)
  if mode == ModeKeys.PREDICT:
    aggregator = training_utils.OutputsAggregator(
        use_steps=True, steps=steps_per_epoch, batch_size=batch_size)
        use_steps=use_steps,
        steps=steps_per_epoch,
        num_samples=num_samples,
        batch_size=batch_size)
  else:
    aggregator = training_utils.MetricsAggregator(
        use_steps=True, steps=steps_per_epoch)
        use_steps=use_steps, steps=steps_per_epoch, num_samples=num_samples)

  callbacks = training_context.callbacks
  progbar = training_context.progbar
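The aggregation choice above reduces to a small decision function; a hedged sketch follows (choose_aggregation is a hypothetical helper, not part of training_v2):

    def choose_aggregation(partial_batch_size, batch_size, steps_per_epoch,
                           dataset_size):
      # Aggregate by sample count only when the final partial batch is known
      # and the epoch covers the whole dataset (steps_per_epoch == dataset_size).
      use_steps = not (partial_batch_size and batch_size and steps_per_epoch
                       and steps_per_epoch == dataset_size)
      num_samples = (None if use_steps else
                     batch_size * (steps_per_epoch - 1) + partial_batch_size)
      return use_steps, num_samples

    # 50 numpy samples with batch_size=4 -> 13 steps, final batch of 2:
    assert choose_aggregation(2, 4, 13, 13) == (False, 50)
    # No known partial batch (e.g. a generator) -> fall back to step counting:
    assert choose_aggregation(None, 4, 13, 13) == (True, None)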
@@ -143,7 +155,14 @@ def run_one_epoch(model,
      if step == 0:
        aggregator.create(batch_outs)
      aggregator.aggregate(batch_outs)

      if use_steps:
        aggregator.aggregate(batch_outs)
      else:
        aggregator.aggregate(
            batch_outs,
            batch_start=step * batch_size,
            batch_end=min((step + 1) * batch_size, num_samples))

      cbks.make_logs(model, batch_logs, batch_outs, mode)
      training_context.callbacks._call_batch_hook(
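The batch_start/batch_end arguments above carve each batch's slot out of the per-sample results, with min() clipping the final partial batch; a minimal sketch of that indexing (batch_bounds is a hypothetical helper):

    def batch_bounds(step, batch_size, num_samples):
      # Half-open index range [batch_start, batch_end) for batch `step`.
      batch_start = step * batch_size
      batch_end = min((step + 1) * batch_size, num_samples)
      return batch_start, batch_end

    # 50 samples, batch_size=4: full batches occupy [0, 4), [4, 8), ..., and
    # the 13th batch (step 12) holds only the last two samples: [48, 50).
    assert batch_bounds(0, 4, 50) == (0, 4)
    assert batch_bounds(12, 4, 50) == (48, 50)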
@@ -286,7 +305,8 @@ class Loop(training_utils.TrainingLoop):
              steps_per_epoch=steps_per_epoch,
              mode=ModeKeys.TRAIN,
              training_context=training_context,
              total_epochs=epochs)
              total_epochs=epochs,
              partial_batch_size=training_data_adapter.partial_batch_size())

          cbks.make_logs(model, epoch_logs, training_result, ModeKeys.TRAIN)

          # Evaluation
# Evaluation
@@ -316,7 +336,9 @@ class Loop(training_utils.TrainingLoop):
                  steps_per_epoch=validation_steps,
                  mode=ModeKeys.TEST,
                  training_context=eval_context,
                  total_epochs=1)
                  total_epochs=1,
                  partial_batch_size=validation_adapter.partial_batch_size())

              cbks.make_logs(model, epoch_logs, eval_result, ModeKeys.TEST,
                             prefix='val_')
@@ -389,7 +411,8 @@ class Loop(training_utils.TrainingLoop):
          steps_per_epoch=steps,
          mode=mode,
          training_context=training_context,
          total_epochs=1)
          total_epochs=1,
          partial_batch_size=adapter.partial_batch_size())

      cbks.make_logs(model, epoch_logs, result, mode)

      if len(result) == 1: