Fix last partial batch loss regression in 2.2

PiperOrigin-RevId: 307666011
Change-Id: I4ede295280b78e18b5b8b52f0c211d5c0a7913e2
This commit is contained in:
Pavithra Vijay 2020-04-21 13:09:50 -07:00 committed by TensorFlower Gardener
parent 81734a896e
commit 4f17f35bef
2 changed files with 119 additions and 101 deletions

View File

@ -192,6 +192,7 @@ class LossesContainer(Container):
loss_values = [] # Used for gradient calculation.
loss_metric_values = [] # Used for loss metric calculation.
batch_dim = None
zip_args = (y_true, y_pred, sample_weight, self._losses, self._loss_weights,
self._per_output_metrics)
for y_t, y_p, sw, loss_obj, loss_weight, metric_obj in zip(*zip_args):
@ -207,8 +208,11 @@ class LossesContainer(Container):
# Correct for the `Mean` loss metrics counting each replica as a batch.
if loss_obj.reduction == losses_utils.ReductionV2.SUM:
loss_metric_value *= ds_context.get_strategy().num_replicas_in_sync
if batch_dim is None:
batch_dim = array_ops.shape(y_t)[0]
if metric_obj is not None:
metric_obj.update_state(loss_metric_value)
metric_obj.update_state(loss_metric_value, sample_weight=batch_dim)
if loss_weight is not None:
loss_value *= loss_weight
@ -232,7 +236,8 @@ class LossesContainer(Container):
loss_metric_values = losses_utils.cast_losses_to_common_dtype(
loss_metric_values)
total_loss_metric_value = math_ops.add_n(loss_metric_values)
self._loss_metric.update_state(total_loss_metric_value)
self._loss_metric.update_state(
total_loss_metric_value, sample_weight=batch_dim)
loss_values = losses_utils.cast_losses_to_common_dtype(loss_values)
total_loss = math_ops.add_n(loss_values)

View File

@ -47,15 +47,14 @@ def get_multi_io_model():
def custom_generator_multi_io(sample_weights=None):
batch_size = 2
num_samples = 4
inputs = np.asarray([[1.], [2.], [3.], [4.]])
targets_1 = np.asarray([[2.], [4.], [6.], [8.]])
targets_2 = np.asarray([[1.], [2.], [3.], [4.]])
i = 0
num_samples = 5
inputs = np.asarray([[1.], [2.], [3.], [4.], [5.]])
targets_1 = np.asarray([[2.], [4.], [6.], [8.], [10.]])
targets_2 = np.asarray([[1.], [2.], [3.], [4.], [5.]])
start = 0
while True:
batch_index = i * batch_size % num_samples
i += 1
start = batch_index
if start > num_samples:
start = 0
end = start + batch_size
x = [inputs[start:end], inputs[start:end]]
y = [targets_1[start:end], targets_2[start:end]]
@ -63,6 +62,7 @@ def custom_generator_multi_io(sample_weights=None):
sw = nest.map_structure(lambda w: w[start:end], sample_weights)
else:
sw = None
start = end
yield x, y, sw
@ -84,97 +84,103 @@ class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase):
def setUp(self):
super(TestMetricsCorrectnessMultiIO, self).setUp()
self.x = np.asarray([[1.], [2.], [3.], [4.]])
self.y1 = np.asarray([[2.], [4.], [6.], [8.]])
self.y2 = np.asarray([[1.], [2.], [3.], [4.]])
self.sample_weight_1 = np.asarray([2., 3., 4., 5.])
self.sample_weight_2 = np.asarray([3.5, 2.5, 1.5, 0.5])
self.x = np.asarray([[1.], [2.], [3.], [4.], [5.]])
self.y1 = np.asarray([[2.], [4.], [6.], [8.], [10.]])
self.y2 = np.asarray([[1.], [2.], [3.], [4.], [5.]])
self.sample_weight_1 = np.asarray([2., 3., 4., 5., 6.])
self.sample_weight_2 = np.asarray([3.5, 2.5, 1.5, 0.5, 3.])
# y_true_1 = [[2.], [4.], [6.], [8.]], y_pred = [[3.], [6.], [9.], [12.]]
# y_true_2 = [[1.], [2.], [3.], [4.]], y_pred = [[3.], [6.], [9.], [12.]]
# y_true_1 = [[2.], [4.], [6.], [8.], [10.]]
# y_pred_1 = [[3.], [6.], [9.], [12.], [15.]]
# y_true_2 = [[1.], [2.], [3.], [4.], [5.]]
# y_pred_2 = [[3.], [6.], [9.], [12.], [15.]]
# Weighted metric `output_1`:
# Total = ((3 - 2)^2 * 2 + (6 - 4)^2 * 3) +
# ((9 - 6)^2 * 4 + (12 - 8)^2 * 5)
# = 130
# Count = (2 + 3) + (4 + 5)
# Result = 9.2857141
# ((9 - 6)^2 * 4 + (12 - 8)^2 * 5) +
# ((15 - 10)^2 * 6)
# = 280
# Count = (2 + 3) + (4 + 5) + 6 = 20
# Result = 14
# Weighted metric `output_2`:
# Total = ((3 - 1)^2 * 3.5 + (6 - 2)^2 * 2.5) +
# ((9 - 3)^2 * 1.5 + (12 - 4)^2 * 0.5)
# = 140
# Count = (3.5 + 2.5) + (1.5 + 0.5)
# Result = 17.5
# ((9 - 3)^2 * 1.5 + (12 - 4)^2 * 0.5) +
# (15 - 5)^2 * 3.0
# = 440
# Count = (3.5 + 2.5) + (1.5 + 0.5) + 3.0 = 11.0
# Result = 40
# Loss `output_1` with weights:
# Total = ((3 - 2)^2 * 2 + (6 - 4)^2 * 3) +
# ((9 - 6)^2 * 4 + (12 - 8)^2 * 5)
# = 130
# Count = 2 + 2
# Result = 32.5
# ((9 - 6)^2 * 4 + (12 - 8)^2 * 5) +
# ((15 - 10)^2 * 6)
# = 280
# Count = 2 + 2 + 1
# Result = 56
# Loss `output_1` without weights/Metric `output_1`:
# Total = ((3 - 2)^2 + (6 - 4)^2) + ((9 - 6)^2 + (12 - 8)^2) = 30
# Count = 2 + 2
# Result = 7.5
# Total = ((3 - 2)^2 + (6 - 4)^2) + ((9 - 6)^2 + (12 - 8)^2) + (15 - 10)^2
# = 55
# Count = 2 + 2 + 1
# Result = 11
# Loss `output_2` with weights:
# Total = ((3 - 1)^2 * 3.5 + (6 - 2)^2 * 2.5) +
# ((9 - 3)^2 * 1.5 + (12 - 4)^2 * 0.5)
# = 140
# Count = 2 + 2
# Result = 35
# ((9 - 3)^2 * 1.5 + (12 - 4)^2 * 0.5) +
# (15 - 5)^2 * 3.0
# = 440
# Count = 2 + 2 + 1
# Result = 88
# Loss `output_2` without weights/Metric `output_2`:
# Total = ((3 - 1)^2 + (6 - 2)^2) + ((9 - 3)^2 + (12 - 4)^2) = 120
# Count = 2 + 2
# Result = 30
# Total = ((3 - 1)^2 + (6 - 2)^2) + ((9 - 3)^2 + (12 - 4)^2) + (15 - 5)^2
# = 220
# Count = 2 + 2 + 1
# Result = 44
# Total loss with weights = 32.5 + 35 = 67.5
# Total loss without weights = 7.5 + 30 = 37.5
# Total loss with weights = 56 + 88 = 144
# Total loss without weights = 11 + 44 = 55
self.wmse = 'mean_squared_error_2'
self.expected_fit_result_with_weights = {
'output_1_mean_squared_error': [7.5, 7.5],
'output_2_mean_squared_error': [30, 30],
'output_1_' + self.wmse: [9.286, 9.286],
'output_2_' + self.wmse: [17.5, 17.5],
'loss': [67.5, 67.5],
'output_1_loss': [32.5, 32.5],
'output_2_loss': [35, 35],
'output_1_mean_squared_error': [11, 11],
'output_2_mean_squared_error': [44, 44],
'output_1_' + self.wmse: [14, 14],
'output_2_' + self.wmse: [40, 40],
'loss': [144, 144],
'output_1_loss': [56, 56],
'output_2_loss': [88, 88],
}
self.expected_fit_result_with_weights_output_2 = {
'output_1_mean_squared_error': [7.5, 7.5],
'output_2_mean_squared_error': [30, 30],
'output_1_' + self.wmse: [7.5, 7.5],
'output_2_' + self.wmse: [17.5, 17.5],
'loss': [42.5, 42.5],
'output_1_loss': [7.5, 7.5],
'output_2_loss': [35, 35],
'output_1_mean_squared_error': [11, 11],
'output_2_mean_squared_error': [44, 44],
'output_1_' + self.wmse: [11, 11],
'output_2_' + self.wmse: [40, 40],
'loss': [99, 99],
'output_1_loss': [11, 11],
'output_2_loss': [88, 88],
}
self.expected_fit_result = {
'output_1_mean_squared_error': [7.5, 7.5],
'output_2_mean_squared_error': [30, 30],
'output_1_' + self.wmse: [7.5, 7.5],
'output_2_' + self.wmse: [30, 30],
'loss': [37.5, 37.5],
'output_1_loss': [7.5, 7.5],
'output_2_loss': [30, 30],
'output_1_mean_squared_error': [11, 11],
'output_2_mean_squared_error': [44, 44],
'output_1_' + self.wmse: [11, 11],
'output_2_' + self.wmse: [44, 44],
'loss': [55, 55],
'output_1_loss': [11, 11],
'output_2_loss': [44, 44],
}
# In the order: 'loss', 'output_1_loss', 'output_2_loss',
# 'output_1_mean_squared_error', 'output_1_mean_squared_error_2',
# 'output_2_mean_squared_error', 'output_2_mean_squared_error_2'
self.expected_batch_result_with_weights = [
67.5, 32.5, 35, 7.5, 9.286, 30, 17.5
]
self.expected_batch_result_with_weights = [144, 56, 88, 11, 14, 44, 40]
self.expected_batch_result_with_weights_output_2 = [
42.5, 7.5, 35, 7.5, 7.5, 30, 17.5
99, 11, 88, 11, 11, 44, 40
]
self.expected_batch_result = [37.5, 7.5, 30, 7.5, 7.5, 30, 30]
self.expected_batch_result = [55, 11, 44, 11, 11, 44, 44]
def test_fit(self):
model = self._get_compiled_multi_io_model()
@ -291,7 +297,7 @@ class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase):
def test_fit_generator(self):
model = self._get_compiled_multi_io_model()
history = model.fit_generator(
custom_generator_multi_io(), steps_per_epoch=2, epochs=2)
custom_generator_multi_io(), steps_per_epoch=3, epochs=2)
for key, value in self.expected_fit_result.items():
self.assertAllClose(history.history[key], value, 1e-3)
@ -300,7 +306,7 @@ class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase):
history = model.fit_generator(
custom_generator_multi_io(
sample_weights=[self.sample_weight_1, self.sample_weight_2]),
steps_per_epoch=2,
steps_per_epoch=3,
epochs=2)
for key, value in self.expected_fit_result_with_weights.items():
self.assertAllClose(history.history[key], value, 1e-3)
@ -309,14 +315,14 @@ class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase):
history = model.fit_generator(
custom_generator_multi_io(
sample_weights={'output_2': self.sample_weight_2}),
steps_per_epoch=2,
steps_per_epoch=3,
epochs=2)
for key, value in self.expected_fit_result_with_weights_output_2.items():
self.assertAllClose(history.history[key], value, 1e-3)
def test_eval_generator(self):
model = self._get_compiled_multi_io_model()
eval_result = model.evaluate_generator(custom_generator_multi_io(), steps=2)
eval_result = model.evaluate_generator(custom_generator_multi_io(), steps=3)
self.assertAllClose(eval_result, self.expected_batch_result, 1e-3)
def test_eval_generator_with_sample_weight(self):
@ -324,7 +330,7 @@ class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase):
eval_result = model.evaluate_generator(
custom_generator_multi_io(
sample_weights=[self.sample_weight_1, self.sample_weight_2]),
steps=2)
steps=3)
self.assertAllClose(eval_result, self.expected_batch_result_with_weights,
1e-3)
@ -332,7 +338,7 @@ class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase):
eval_result = model.evaluate_generator(
custom_generator_multi_io(
sample_weights={'output_2': self.sample_weight_2}),
steps=2)
steps=3)
self.assertAllClose(eval_result,
self.expected_batch_result_with_weights_output_2, 1e-3)
@ -549,7 +555,7 @@ class TestMetricsCorrectnessSingleIO(keras_parameterized.TestCase):
@keras_parameterized.run_with_all_model_types(exclude_models=['sequential'])
@keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
@parameterized.parameters([
loss_reduction.ReductionV2.SUM_OVER_BATCH_SIZE,
loss_reduction.ReductionV2.AUTO,
@ -567,29 +573,34 @@ class TestOutputLossMetrics(keras_parameterized.TestCase):
def setUp(self):
super(TestOutputLossMetrics, self).setUp()
self.x = np.asarray([[1.], [2.], [3.], [4.]])
self.y1 = np.asarray([[2.], [4.], [6.], [8.]])
self.y2 = np.asarray([[1.], [2.], [3.], [4.]])
self.sample_weight_1 = np.asarray([2., 3., 4., 5.])
self.sample_weight_2 = np.asarray([3.5, 2.5, 1.5, 0.5])
self.x = np.asarray([[1.], [2.], [3.], [4.], [5.]])
self.y1 = np.asarray([[2.], [4.], [6.], [8.], [10.]])
self.y2 = np.asarray([[1.], [2.], [3.], [4.], [5.]])
self.sample_weight_1 = np.asarray([2., 3., 4., 5., 6.])
self.sample_weight_2 = np.asarray([3.5, 2.5, 1.5, 0.5, 3.])
# y_true = [[2.], [4.], [6.], [8.]], y_pred = [[3.], [6.], [9.], [12.]]
# y_true_1 = [[2.], [4.], [6.], [8.], [10.]]
# y_pred_1 = [[3.], [6.], [9.], [12.], [15.]]
# y_true_2 = [[1.], [2.], [3.], [4.], [5.]]
# y_pred_2 = [[3.], [6.], [9.], [12.], [15.]]
# Loss `output_1`:
# Per-sample weighted losses
# Batch 1 = [(3 - 2)^2 * 2, (6 - 4)^2 * 3)] = [2, 12]
# Batch 2 = [((9 - 6)^2 * 4, (12 - 8)^2 * 5)] = [36, 80]
# Batch 3 = [(15 - 10)^2 * 6] = [150]
# Result (reduction=SUM) = ((2 + 12) + (36 + 80))/2 = 65
# Result (reduction=SUM_OVER_BATCH_SIZE/AUTO/NONE) = 130 / 4 = 32.5
# Result (reduction=SUM) = ((2 + 12)*2 + (36 + 80)*2 + 150) / 5 = 82
# Result (reduction=SUM_OVER_BATCH_SIZE/AUTO/NONE) = 280 / 5 = 56
# Loss `output_2`:
# Per-sample weighted losses
# Batch 1 = [(3 - 1)^2 * 3.5, (6 - 2)^2 * 2.5)] = [14, 40]
# Batch 2 = [(9 - 3)^2 * 1.5, (12 - 4)^2 * 0.5)] = [54, 32]
# Batch 3 = [(15 - 5)^2 * 3] = [300]
# Result (reduction=SUM) = ((14 + 40) + (54 + 32))/2 = 70
# Result (reduction=SUM_OVER_BATCH_SIZE/AUTO/NONE) = 140 / 4 = 35
# Result (reduction=SUM) = ((14 + 40)*2 + (54 + 32)*2 + 300) / 5 = 116
# Result (reduction=SUM_OVER_BATCH_SIZE/AUTO/NONE) = 440 / 5 = 88
# When reduction is 'NONE' loss value that is passed to the optimizer will
# be vector loss but what is reported is a scalar, which is an average of
@ -598,18 +609,18 @@ class TestOutputLossMetrics(keras_parameterized.TestCase):
# Total loss = Output_loss_1 + Output_loss_2
sum_over_batch_size_fit_result = {
'loss': [67.5, 67.5],
'output_1_loss': [32.5, 32.5],
'output_2_loss': [35, 35],
'loss': [144, 144],
'output_1_loss': [56, 56],
'output_2_loss': [88, 88],
}
self.expected_fit_result = {
loss_reduction.ReductionV2.NONE:
sum_over_batch_size_fit_result,
loss_reduction.ReductionV2.SUM: {
'loss': [135, 135],
'output_1_loss': [65, 65],
'output_2_loss': [70, 70],
'loss': [198, 198],
'output_1_loss': [82, 82],
'output_2_loss': [116, 116],
},
loss_reduction.ReductionV2.AUTO:
sum_over_batch_size_fit_result,
@ -619,12 +630,16 @@ class TestOutputLossMetrics(keras_parameterized.TestCase):
# In the order: 'loss', 'output_1_loss', 'output_2_loss',
self.expected_batch_result = {
loss_reduction.ReductionV2.NONE: [67.5, 32.5, 35],
loss_reduction.ReductionV2.SUM: [135, 65, 70],
loss_reduction.ReductionV2.AUTO: [67.5, 32.5, 35],
loss_reduction.ReductionV2.SUM_OVER_BATCH_SIZE: [67.5, 32.5, 35],
loss_reduction.ReductionV2.NONE: [144, 56, 88],
loss_reduction.ReductionV2.SUM: [198, 82, 116],
loss_reduction.ReductionV2.AUTO: [144, 56, 88],
loss_reduction.ReductionV2.SUM_OVER_BATCH_SIZE: [144, 56, 88],
}
# 2 + 12 + 36 + 80 + 150 = 280
# 14 + 40 + 54 + 32 + 300 = 440
self.expected_single_batch_result = [720, 280, 440]
def test_fit(self, reduction):
model = self._get_compiled_multi_io_model(
loss=losses.MeanSquaredError(reduction=reduction))
@ -661,8 +676,7 @@ class TestOutputLossMetrics(keras_parameterized.TestCase):
expected_values = self.expected_batch_result[reduction]
if reduction == loss_reduction.ReductionV2.SUM:
# We are taking all the data as one batch, so undo the averaging here.
expected_values = [x * 2 for x in self.expected_batch_result[reduction]]
expected_values = self.expected_single_batch_result
self.assertAllClose(result, expected_values)
def test_test_on_batch(self, reduction):
@ -675,8 +689,7 @@ class TestOutputLossMetrics(keras_parameterized.TestCase):
})
expected_values = self.expected_batch_result[reduction]
if reduction == loss_reduction.ReductionV2.SUM:
# We are taking all the data as one batch, so undo the averaging here.
expected_values = [x * 2 for x in self.expected_batch_result[reduction]]
expected_values = self.expected_single_batch_result
self.assertAllClose(result, expected_values)
def test_fit_generator(self, reduction):
@ -685,7 +698,7 @@ class TestOutputLossMetrics(keras_parameterized.TestCase):
history = model.fit_generator(
custom_generator_multi_io(
sample_weights=[self.sample_weight_1, self.sample_weight_2]),
steps_per_epoch=2,
steps_per_epoch=3,
epochs=2)
for key, value in self.expected_fit_result[reduction].items():
self.assertAllClose(history.history[key], value)
@ -696,7 +709,7 @@ class TestOutputLossMetrics(keras_parameterized.TestCase):
eval_result = model.evaluate_generator(
custom_generator_multi_io(
sample_weights=[self.sample_weight_1, self.sample_weight_2]),
steps=2)
steps=3)
self.assertAllClose(eval_result, self.expected_batch_result[reduction])