Fix last partial batch loss regression in 2.2
PiperOrigin-RevId: 307666011 Change-Id: I4ede295280b78e18b5b8b52f0c211d5c0a7913e2
This commit is contained in:
parent
81734a896e
commit
4f17f35bef
@ -192,6 +192,7 @@ class LossesContainer(Container):
|
|||||||
|
|
||||||
loss_values = [] # Used for gradient calculation.
|
loss_values = [] # Used for gradient calculation.
|
||||||
loss_metric_values = [] # Used for loss metric calculation.
|
loss_metric_values = [] # Used for loss metric calculation.
|
||||||
|
batch_dim = None
|
||||||
zip_args = (y_true, y_pred, sample_weight, self._losses, self._loss_weights,
|
zip_args = (y_true, y_pred, sample_weight, self._losses, self._loss_weights,
|
||||||
self._per_output_metrics)
|
self._per_output_metrics)
|
||||||
for y_t, y_p, sw, loss_obj, loss_weight, metric_obj in zip(*zip_args):
|
for y_t, y_p, sw, loss_obj, loss_weight, metric_obj in zip(*zip_args):
|
||||||
@ -207,8 +208,11 @@ class LossesContainer(Container):
|
|||||||
# Correct for the `Mean` loss metrics counting each replica as a batch.
|
# Correct for the `Mean` loss metrics counting each replica as a batch.
|
||||||
if loss_obj.reduction == losses_utils.ReductionV2.SUM:
|
if loss_obj.reduction == losses_utils.ReductionV2.SUM:
|
||||||
loss_metric_value *= ds_context.get_strategy().num_replicas_in_sync
|
loss_metric_value *= ds_context.get_strategy().num_replicas_in_sync
|
||||||
|
|
||||||
|
if batch_dim is None:
|
||||||
|
batch_dim = array_ops.shape(y_t)[0]
|
||||||
if metric_obj is not None:
|
if metric_obj is not None:
|
||||||
metric_obj.update_state(loss_metric_value)
|
metric_obj.update_state(loss_metric_value, sample_weight=batch_dim)
|
||||||
|
|
||||||
if loss_weight is not None:
|
if loss_weight is not None:
|
||||||
loss_value *= loss_weight
|
loss_value *= loss_weight
|
||||||
@ -232,7 +236,8 @@ class LossesContainer(Container):
|
|||||||
loss_metric_values = losses_utils.cast_losses_to_common_dtype(
|
loss_metric_values = losses_utils.cast_losses_to_common_dtype(
|
||||||
loss_metric_values)
|
loss_metric_values)
|
||||||
total_loss_metric_value = math_ops.add_n(loss_metric_values)
|
total_loss_metric_value = math_ops.add_n(loss_metric_values)
|
||||||
self._loss_metric.update_state(total_loss_metric_value)
|
self._loss_metric.update_state(
|
||||||
|
total_loss_metric_value, sample_weight=batch_dim)
|
||||||
|
|
||||||
loss_values = losses_utils.cast_losses_to_common_dtype(loss_values)
|
loss_values = losses_utils.cast_losses_to_common_dtype(loss_values)
|
||||||
total_loss = math_ops.add_n(loss_values)
|
total_loss = math_ops.add_n(loss_values)
|
||||||
|
@ -47,15 +47,14 @@ def get_multi_io_model():
|
|||||||
|
|
||||||
def custom_generator_multi_io(sample_weights=None):
|
def custom_generator_multi_io(sample_weights=None):
|
||||||
batch_size = 2
|
batch_size = 2
|
||||||
num_samples = 4
|
num_samples = 5
|
||||||
inputs = np.asarray([[1.], [2.], [3.], [4.]])
|
inputs = np.asarray([[1.], [2.], [3.], [4.], [5.]])
|
||||||
targets_1 = np.asarray([[2.], [4.], [6.], [8.]])
|
targets_1 = np.asarray([[2.], [4.], [6.], [8.], [10.]])
|
||||||
targets_2 = np.asarray([[1.], [2.], [3.], [4.]])
|
targets_2 = np.asarray([[1.], [2.], [3.], [4.], [5.]])
|
||||||
i = 0
|
start = 0
|
||||||
while True:
|
while True:
|
||||||
batch_index = i * batch_size % num_samples
|
if start > num_samples:
|
||||||
i += 1
|
start = 0
|
||||||
start = batch_index
|
|
||||||
end = start + batch_size
|
end = start + batch_size
|
||||||
x = [inputs[start:end], inputs[start:end]]
|
x = [inputs[start:end], inputs[start:end]]
|
||||||
y = [targets_1[start:end], targets_2[start:end]]
|
y = [targets_1[start:end], targets_2[start:end]]
|
||||||
@ -63,6 +62,7 @@ def custom_generator_multi_io(sample_weights=None):
|
|||||||
sw = nest.map_structure(lambda w: w[start:end], sample_weights)
|
sw = nest.map_structure(lambda w: w[start:end], sample_weights)
|
||||||
else:
|
else:
|
||||||
sw = None
|
sw = None
|
||||||
|
start = end
|
||||||
yield x, y, sw
|
yield x, y, sw
|
||||||
|
|
||||||
|
|
||||||
@ -84,97 +84,103 @@ class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(TestMetricsCorrectnessMultiIO, self).setUp()
|
super(TestMetricsCorrectnessMultiIO, self).setUp()
|
||||||
self.x = np.asarray([[1.], [2.], [3.], [4.]])
|
self.x = np.asarray([[1.], [2.], [3.], [4.], [5.]])
|
||||||
self.y1 = np.asarray([[2.], [4.], [6.], [8.]])
|
self.y1 = np.asarray([[2.], [4.], [6.], [8.], [10.]])
|
||||||
self.y2 = np.asarray([[1.], [2.], [3.], [4.]])
|
self.y2 = np.asarray([[1.], [2.], [3.], [4.], [5.]])
|
||||||
self.sample_weight_1 = np.asarray([2., 3., 4., 5.])
|
self.sample_weight_1 = np.asarray([2., 3., 4., 5., 6.])
|
||||||
self.sample_weight_2 = np.asarray([3.5, 2.5, 1.5, 0.5])
|
self.sample_weight_2 = np.asarray([3.5, 2.5, 1.5, 0.5, 3.])
|
||||||
|
|
||||||
# y_true_1 = [[2.], [4.], [6.], [8.]], y_pred = [[3.], [6.], [9.], [12.]]
|
# y_true_1 = [[2.], [4.], [6.], [8.], [10.]]
|
||||||
# y_true_2 = [[1.], [2.], [3.], [4.]], y_pred = [[3.], [6.], [9.], [12.]]
|
# y_pred_1 = [[3.], [6.], [9.], [12.], [15.]]
|
||||||
|
# y_true_2 = [[1.], [2.], [3.], [4.], [5.]]
|
||||||
|
# y_pred_2 = [[3.], [6.], [9.], [12.], [15.]]
|
||||||
|
|
||||||
# Weighted metric `output_1`:
|
# Weighted metric `output_1`:
|
||||||
# Total = ((3 - 2)^2 * 2 + (6 - 4)^2 * 3) +
|
# Total = ((3 - 2)^2 * 2 + (6 - 4)^2 * 3) +
|
||||||
# ((9 - 6)^2 * 4 + (12 - 8)^2 * 5)
|
# ((9 - 6)^2 * 4 + (12 - 8)^2 * 5) +
|
||||||
# = 130
|
# ((15 - 10)^2 * 6)
|
||||||
# Count = (2 + 3) + (4 + 5)
|
# = 280
|
||||||
# Result = 9.2857141
|
# Count = (2 + 3) + (4 + 5) + 6 = 20
|
||||||
|
# Result = 14
|
||||||
|
|
||||||
# Weighted metric `output_2`:
|
# Weighted metric `output_2`:
|
||||||
# Total = ((3 - 1)^2 * 3.5 + (6 - 2)^2 * 2.5) +
|
# Total = ((3 - 1)^2 * 3.5 + (6 - 2)^2 * 2.5) +
|
||||||
# ((9 - 3)^2 * 1.5 + (12 - 4)^2 * 0.5)
|
# ((9 - 3)^2 * 1.5 + (12 - 4)^2 * 0.5) +
|
||||||
# = 140
|
# (15 - 5)^2 * 3.0
|
||||||
# Count = (3.5 + 2.5) + (1.5 + 0.5)
|
# = 440
|
||||||
# Result = 17.5
|
# Count = (3.5 + 2.5) + (1.5 + 0.5) + 3.0 = 11.0
|
||||||
|
# Result = 40
|
||||||
|
|
||||||
# Loss `output_1` with weights:
|
# Loss `output_1` with weights:
|
||||||
# Total = ((3 - 2)^2 * 2 + (6 - 4)^2 * 3) +
|
# Total = ((3 - 2)^2 * 2 + (6 - 4)^2 * 3) +
|
||||||
# ((9 - 6)^2 * 4 + (12 - 8)^2 * 5)
|
# ((9 - 6)^2 * 4 + (12 - 8)^2 * 5) +
|
||||||
# = 130
|
# ((15 - 10)^2 * 6)
|
||||||
# Count = 2 + 2
|
# = 280
|
||||||
# Result = 32.5
|
# Count = 2 + 2 + 1
|
||||||
|
# Result = 56
|
||||||
|
|
||||||
# Loss `output_1` without weights/Metric `output_1`:
|
# Loss `output_1` without weights/Metric `output_1`:
|
||||||
# Total = ((3 - 2)^2 + (6 - 4)^2) + ((9 - 6)^2 + (12 - 8)^2) = 30
|
# Total = ((3 - 2)^2 + (6 - 4)^2) + ((9 - 6)^2 + (12 - 8)^2) + (15 - 10)^2
|
||||||
# Count = 2 + 2
|
# = 55
|
||||||
# Result = 7.5
|
# Count = 2 + 2 + 1
|
||||||
|
# Result = 11
|
||||||
|
|
||||||
# Loss `output_2` with weights:
|
# Loss `output_2` with weights:
|
||||||
# Total = ((3 - 1)^2 * 3.5 + (6 - 2)^2 * 2.5) +
|
# Total = ((3 - 1)^2 * 3.5 + (6 - 2)^2 * 2.5) +
|
||||||
# ((9 - 3)^2 * 1.5 + (12 - 4)^2 * 0.5)
|
# ((9 - 3)^2 * 1.5 + (12 - 4)^2 * 0.5) +
|
||||||
# = 140
|
# (15 - 5)^2 * 3.0
|
||||||
# Count = 2 + 2
|
# = 440
|
||||||
# Result = 35
|
# Count = 2 + 2 + 1
|
||||||
|
# Result = 88
|
||||||
|
|
||||||
# Loss `output_2` without weights/Metric `output_2`:
|
# Loss `output_2` without weights/Metric `output_2`:
|
||||||
# Total = ((3 - 1)^2 + (6 - 2)^2) + ((9 - 3)^2 + (12 - 4)^2) = 120
|
# Total = ((3 - 1)^2 + (6 - 2)^2) + ((9 - 3)^2 + (12 - 4)^2) + (15 - 5)^2
|
||||||
# Count = 2 + 2
|
# = 220
|
||||||
# Result = 30
|
# Count = 2 + 2 + 1
|
||||||
|
# Result = 44
|
||||||
|
|
||||||
# Total loss with weights = 32.5 + 35 = 67.5
|
# Total loss with weights = 56 + 88 = 144
|
||||||
# Total loss without weights = 7.5 + 30 = 37.5
|
# Total loss without weights = 11 + 44 = 55
|
||||||
|
|
||||||
self.wmse = 'mean_squared_error_2'
|
self.wmse = 'mean_squared_error_2'
|
||||||
self.expected_fit_result_with_weights = {
|
self.expected_fit_result_with_weights = {
|
||||||
'output_1_mean_squared_error': [7.5, 7.5],
|
'output_1_mean_squared_error': [11, 11],
|
||||||
'output_2_mean_squared_error': [30, 30],
|
'output_2_mean_squared_error': [44, 44],
|
||||||
'output_1_' + self.wmse: [9.286, 9.286],
|
'output_1_' + self.wmse: [14, 14],
|
||||||
'output_2_' + self.wmse: [17.5, 17.5],
|
'output_2_' + self.wmse: [40, 40],
|
||||||
'loss': [67.5, 67.5],
|
'loss': [144, 144],
|
||||||
'output_1_loss': [32.5, 32.5],
|
'output_1_loss': [56, 56],
|
||||||
'output_2_loss': [35, 35],
|
'output_2_loss': [88, 88],
|
||||||
}
|
}
|
||||||
|
|
||||||
self.expected_fit_result_with_weights_output_2 = {
|
self.expected_fit_result_with_weights_output_2 = {
|
||||||
'output_1_mean_squared_error': [7.5, 7.5],
|
'output_1_mean_squared_error': [11, 11],
|
||||||
'output_2_mean_squared_error': [30, 30],
|
'output_2_mean_squared_error': [44, 44],
|
||||||
'output_1_' + self.wmse: [7.5, 7.5],
|
'output_1_' + self.wmse: [11, 11],
|
||||||
'output_2_' + self.wmse: [17.5, 17.5],
|
'output_2_' + self.wmse: [40, 40],
|
||||||
'loss': [42.5, 42.5],
|
'loss': [99, 99],
|
||||||
'output_1_loss': [7.5, 7.5],
|
'output_1_loss': [11, 11],
|
||||||
'output_2_loss': [35, 35],
|
'output_2_loss': [88, 88],
|
||||||
}
|
}
|
||||||
|
|
||||||
self.expected_fit_result = {
|
self.expected_fit_result = {
|
||||||
'output_1_mean_squared_error': [7.5, 7.5],
|
'output_1_mean_squared_error': [11, 11],
|
||||||
'output_2_mean_squared_error': [30, 30],
|
'output_2_mean_squared_error': [44, 44],
|
||||||
'output_1_' + self.wmse: [7.5, 7.5],
|
'output_1_' + self.wmse: [11, 11],
|
||||||
'output_2_' + self.wmse: [30, 30],
|
'output_2_' + self.wmse: [44, 44],
|
||||||
'loss': [37.5, 37.5],
|
'loss': [55, 55],
|
||||||
'output_1_loss': [7.5, 7.5],
|
'output_1_loss': [11, 11],
|
||||||
'output_2_loss': [30, 30],
|
'output_2_loss': [44, 44],
|
||||||
}
|
}
|
||||||
|
|
||||||
# In the order: 'loss', 'output_1_loss', 'output_2_loss',
|
# In the order: 'loss', 'output_1_loss', 'output_2_loss',
|
||||||
# 'output_1_mean_squared_error', 'output_1_mean_squared_error_2',
|
# 'output_1_mean_squared_error', 'output_1_mean_squared_error_2',
|
||||||
# 'output_2_mean_squared_error', 'output_2_mean_squared_error_2'
|
# 'output_2_mean_squared_error', 'output_2_mean_squared_error_2'
|
||||||
self.expected_batch_result_with_weights = [
|
self.expected_batch_result_with_weights = [144, 56, 88, 11, 14, 44, 40]
|
||||||
67.5, 32.5, 35, 7.5, 9.286, 30, 17.5
|
|
||||||
]
|
|
||||||
self.expected_batch_result_with_weights_output_2 = [
|
self.expected_batch_result_with_weights_output_2 = [
|
||||||
42.5, 7.5, 35, 7.5, 7.5, 30, 17.5
|
99, 11, 88, 11, 11, 44, 40
|
||||||
]
|
]
|
||||||
self.expected_batch_result = [37.5, 7.5, 30, 7.5, 7.5, 30, 30]
|
self.expected_batch_result = [55, 11, 44, 11, 11, 44, 44]
|
||||||
|
|
||||||
def test_fit(self):
|
def test_fit(self):
|
||||||
model = self._get_compiled_multi_io_model()
|
model = self._get_compiled_multi_io_model()
|
||||||
@ -291,7 +297,7 @@ class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase):
|
|||||||
def test_fit_generator(self):
|
def test_fit_generator(self):
|
||||||
model = self._get_compiled_multi_io_model()
|
model = self._get_compiled_multi_io_model()
|
||||||
history = model.fit_generator(
|
history = model.fit_generator(
|
||||||
custom_generator_multi_io(), steps_per_epoch=2, epochs=2)
|
custom_generator_multi_io(), steps_per_epoch=3, epochs=2)
|
||||||
for key, value in self.expected_fit_result.items():
|
for key, value in self.expected_fit_result.items():
|
||||||
self.assertAllClose(history.history[key], value, 1e-3)
|
self.assertAllClose(history.history[key], value, 1e-3)
|
||||||
|
|
||||||
@ -300,7 +306,7 @@ class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase):
|
|||||||
history = model.fit_generator(
|
history = model.fit_generator(
|
||||||
custom_generator_multi_io(
|
custom_generator_multi_io(
|
||||||
sample_weights=[self.sample_weight_1, self.sample_weight_2]),
|
sample_weights=[self.sample_weight_1, self.sample_weight_2]),
|
||||||
steps_per_epoch=2,
|
steps_per_epoch=3,
|
||||||
epochs=2)
|
epochs=2)
|
||||||
for key, value in self.expected_fit_result_with_weights.items():
|
for key, value in self.expected_fit_result_with_weights.items():
|
||||||
self.assertAllClose(history.history[key], value, 1e-3)
|
self.assertAllClose(history.history[key], value, 1e-3)
|
||||||
@ -309,14 +315,14 @@ class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase):
|
|||||||
history = model.fit_generator(
|
history = model.fit_generator(
|
||||||
custom_generator_multi_io(
|
custom_generator_multi_io(
|
||||||
sample_weights={'output_2': self.sample_weight_2}),
|
sample_weights={'output_2': self.sample_weight_2}),
|
||||||
steps_per_epoch=2,
|
steps_per_epoch=3,
|
||||||
epochs=2)
|
epochs=2)
|
||||||
for key, value in self.expected_fit_result_with_weights_output_2.items():
|
for key, value in self.expected_fit_result_with_weights_output_2.items():
|
||||||
self.assertAllClose(history.history[key], value, 1e-3)
|
self.assertAllClose(history.history[key], value, 1e-3)
|
||||||
|
|
||||||
def test_eval_generator(self):
|
def test_eval_generator(self):
|
||||||
model = self._get_compiled_multi_io_model()
|
model = self._get_compiled_multi_io_model()
|
||||||
eval_result = model.evaluate_generator(custom_generator_multi_io(), steps=2)
|
eval_result = model.evaluate_generator(custom_generator_multi_io(), steps=3)
|
||||||
self.assertAllClose(eval_result, self.expected_batch_result, 1e-3)
|
self.assertAllClose(eval_result, self.expected_batch_result, 1e-3)
|
||||||
|
|
||||||
def test_eval_generator_with_sample_weight(self):
|
def test_eval_generator_with_sample_weight(self):
|
||||||
@ -324,7 +330,7 @@ class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase):
|
|||||||
eval_result = model.evaluate_generator(
|
eval_result = model.evaluate_generator(
|
||||||
custom_generator_multi_io(
|
custom_generator_multi_io(
|
||||||
sample_weights=[self.sample_weight_1, self.sample_weight_2]),
|
sample_weights=[self.sample_weight_1, self.sample_weight_2]),
|
||||||
steps=2)
|
steps=3)
|
||||||
self.assertAllClose(eval_result, self.expected_batch_result_with_weights,
|
self.assertAllClose(eval_result, self.expected_batch_result_with_weights,
|
||||||
1e-3)
|
1e-3)
|
||||||
|
|
||||||
@ -332,7 +338,7 @@ class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase):
|
|||||||
eval_result = model.evaluate_generator(
|
eval_result = model.evaluate_generator(
|
||||||
custom_generator_multi_io(
|
custom_generator_multi_io(
|
||||||
sample_weights={'output_2': self.sample_weight_2}),
|
sample_weights={'output_2': self.sample_weight_2}),
|
||||||
steps=2)
|
steps=3)
|
||||||
self.assertAllClose(eval_result,
|
self.assertAllClose(eval_result,
|
||||||
self.expected_batch_result_with_weights_output_2, 1e-3)
|
self.expected_batch_result_with_weights_output_2, 1e-3)
|
||||||
|
|
||||||
@ -549,7 +555,7 @@ class TestMetricsCorrectnessSingleIO(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
@keras_parameterized.run_with_all_model_types(exclude_models=['sequential'])
|
@keras_parameterized.run_with_all_model_types(exclude_models=['sequential'])
|
||||||
@keras_parameterized.run_all_keras_modes
|
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||||
@parameterized.parameters([
|
@parameterized.parameters([
|
||||||
loss_reduction.ReductionV2.SUM_OVER_BATCH_SIZE,
|
loss_reduction.ReductionV2.SUM_OVER_BATCH_SIZE,
|
||||||
loss_reduction.ReductionV2.AUTO,
|
loss_reduction.ReductionV2.AUTO,
|
||||||
@ -567,29 +573,34 @@ class TestOutputLossMetrics(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(TestOutputLossMetrics, self).setUp()
|
super(TestOutputLossMetrics, self).setUp()
|
||||||
self.x = np.asarray([[1.], [2.], [3.], [4.]])
|
self.x = np.asarray([[1.], [2.], [3.], [4.], [5.]])
|
||||||
self.y1 = np.asarray([[2.], [4.], [6.], [8.]])
|
self.y1 = np.asarray([[2.], [4.], [6.], [8.], [10.]])
|
||||||
self.y2 = np.asarray([[1.], [2.], [3.], [4.]])
|
self.y2 = np.asarray([[1.], [2.], [3.], [4.], [5.]])
|
||||||
self.sample_weight_1 = np.asarray([2., 3., 4., 5.])
|
self.sample_weight_1 = np.asarray([2., 3., 4., 5., 6.])
|
||||||
self.sample_weight_2 = np.asarray([3.5, 2.5, 1.5, 0.5])
|
self.sample_weight_2 = np.asarray([3.5, 2.5, 1.5, 0.5, 3.])
|
||||||
|
|
||||||
# y_true = [[2.], [4.], [6.], [8.]], y_pred = [[3.], [6.], [9.], [12.]]
|
# y_true_1 = [[2.], [4.], [6.], [8.], [10.]]
|
||||||
|
# y_pred_1 = [[3.], [6.], [9.], [12.], [15.]]
|
||||||
|
# y_true_2 = [[1.], [2.], [3.], [4.], [5.]]
|
||||||
|
# y_pred_2 = [[3.], [6.], [9.], [12.], [15.]]
|
||||||
|
|
||||||
# Loss `output_1`:
|
# Loss `output_1`:
|
||||||
# Per-sample weighted losses
|
# Per-sample weighted losses
|
||||||
# Batch 1 = [(3 - 2)^2 * 2, (6 - 4)^2 * 3)] = [2, 12]
|
# Batch 1 = [(3 - 2)^2 * 2, (6 - 4)^2 * 3)] = [2, 12]
|
||||||
# Batch 2 = [((9 - 6)^2 * 4, (12 - 8)^2 * 5)] = [36, 80]
|
# Batch 2 = [((9 - 6)^2 * 4, (12 - 8)^2 * 5)] = [36, 80]
|
||||||
|
# Batch 3 = [(15 - 10)^2 * 6] = [150]
|
||||||
|
|
||||||
# Result (reduction=SUM) = ((2 + 12) + (36 + 80))/2 = 65
|
# Result (reduction=SUM) = ((2 + 12)*2 + (36 + 80)*2 + 150) / 5 = 82
|
||||||
# Result (reduction=SUM_OVER_BATCH_SIZE/AUTO/NONE) = 130 / 4 = 32.5
|
# Result (reduction=SUM_OVER_BATCH_SIZE/AUTO/NONE) = 280 / 5 = 56
|
||||||
|
|
||||||
# Loss `output_2`:
|
# Loss `output_2`:
|
||||||
# Per-sample weighted losses
|
# Per-sample weighted losses
|
||||||
# Batch 1 = [(3 - 1)^2 * 3.5, (6 - 2)^2 * 2.5)] = [14, 40]
|
# Batch 1 = [(3 - 1)^2 * 3.5, (6 - 2)^2 * 2.5)] = [14, 40]
|
||||||
# Batch 2 = [(9 - 3)^2 * 1.5, (12 - 4)^2 * 0.5)] = [54, 32]
|
# Batch 2 = [(9 - 3)^2 * 1.5, (12 - 4)^2 * 0.5)] = [54, 32]
|
||||||
|
# Batch 3 = [(15 - 5)^2 * 3] = [300]
|
||||||
|
|
||||||
# Result (reduction=SUM) = ((14 + 40) + (54 + 32))/2 = 70
|
# Result (reduction=SUM) = ((14 + 40)*2 + (54 + 32)*2 + 300) / 5 = 116
|
||||||
# Result (reduction=SUM_OVER_BATCH_SIZE/AUTO/NONE) = 140 / 4 = 35
|
# Result (reduction=SUM_OVER_BATCH_SIZE/AUTO/NONE) = 440 / 5 = 88
|
||||||
|
|
||||||
# When reduction is 'NONE' loss value that is passed to the optimizer will
|
# When reduction is 'NONE' loss value that is passed to the optimizer will
|
||||||
# be vector loss but what is reported is a scalar, which is an average of
|
# be vector loss but what is reported is a scalar, which is an average of
|
||||||
@ -598,18 +609,18 @@ class TestOutputLossMetrics(keras_parameterized.TestCase):
|
|||||||
# Total loss = Output_loss_1 + Output_loss_2
|
# Total loss = Output_loss_1 + Output_loss_2
|
||||||
|
|
||||||
sum_over_batch_size_fit_result = {
|
sum_over_batch_size_fit_result = {
|
||||||
'loss': [67.5, 67.5],
|
'loss': [144, 144],
|
||||||
'output_1_loss': [32.5, 32.5],
|
'output_1_loss': [56, 56],
|
||||||
'output_2_loss': [35, 35],
|
'output_2_loss': [88, 88],
|
||||||
}
|
}
|
||||||
|
|
||||||
self.expected_fit_result = {
|
self.expected_fit_result = {
|
||||||
loss_reduction.ReductionV2.NONE:
|
loss_reduction.ReductionV2.NONE:
|
||||||
sum_over_batch_size_fit_result,
|
sum_over_batch_size_fit_result,
|
||||||
loss_reduction.ReductionV2.SUM: {
|
loss_reduction.ReductionV2.SUM: {
|
||||||
'loss': [135, 135],
|
'loss': [198, 198],
|
||||||
'output_1_loss': [65, 65],
|
'output_1_loss': [82, 82],
|
||||||
'output_2_loss': [70, 70],
|
'output_2_loss': [116, 116],
|
||||||
},
|
},
|
||||||
loss_reduction.ReductionV2.AUTO:
|
loss_reduction.ReductionV2.AUTO:
|
||||||
sum_over_batch_size_fit_result,
|
sum_over_batch_size_fit_result,
|
||||||
@ -619,12 +630,16 @@ class TestOutputLossMetrics(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
# In the order: 'loss', 'output_1_loss', 'output_2_loss',
|
# In the order: 'loss', 'output_1_loss', 'output_2_loss',
|
||||||
self.expected_batch_result = {
|
self.expected_batch_result = {
|
||||||
loss_reduction.ReductionV2.NONE: [67.5, 32.5, 35],
|
loss_reduction.ReductionV2.NONE: [144, 56, 88],
|
||||||
loss_reduction.ReductionV2.SUM: [135, 65, 70],
|
loss_reduction.ReductionV2.SUM: [198, 82, 116],
|
||||||
loss_reduction.ReductionV2.AUTO: [67.5, 32.5, 35],
|
loss_reduction.ReductionV2.AUTO: [144, 56, 88],
|
||||||
loss_reduction.ReductionV2.SUM_OVER_BATCH_SIZE: [67.5, 32.5, 35],
|
loss_reduction.ReductionV2.SUM_OVER_BATCH_SIZE: [144, 56, 88],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 2 + 12 + 36 + 80 + 150 = 280
|
||||||
|
# 14 + 40 + 54 + 32 + 300 = 440
|
||||||
|
self.expected_single_batch_result = [720, 280, 440]
|
||||||
|
|
||||||
def test_fit(self, reduction):
|
def test_fit(self, reduction):
|
||||||
model = self._get_compiled_multi_io_model(
|
model = self._get_compiled_multi_io_model(
|
||||||
loss=losses.MeanSquaredError(reduction=reduction))
|
loss=losses.MeanSquaredError(reduction=reduction))
|
||||||
@ -661,8 +676,7 @@ class TestOutputLossMetrics(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
expected_values = self.expected_batch_result[reduction]
|
expected_values = self.expected_batch_result[reduction]
|
||||||
if reduction == loss_reduction.ReductionV2.SUM:
|
if reduction == loss_reduction.ReductionV2.SUM:
|
||||||
# We are taking all the data as one batch, so undo the averaging here.
|
expected_values = self.expected_single_batch_result
|
||||||
expected_values = [x * 2 for x in self.expected_batch_result[reduction]]
|
|
||||||
self.assertAllClose(result, expected_values)
|
self.assertAllClose(result, expected_values)
|
||||||
|
|
||||||
def test_test_on_batch(self, reduction):
|
def test_test_on_batch(self, reduction):
|
||||||
@ -675,8 +689,7 @@ class TestOutputLossMetrics(keras_parameterized.TestCase):
|
|||||||
})
|
})
|
||||||
expected_values = self.expected_batch_result[reduction]
|
expected_values = self.expected_batch_result[reduction]
|
||||||
if reduction == loss_reduction.ReductionV2.SUM:
|
if reduction == loss_reduction.ReductionV2.SUM:
|
||||||
# We are taking all the data as one batch, so undo the averaging here.
|
expected_values = self.expected_single_batch_result
|
||||||
expected_values = [x * 2 for x in self.expected_batch_result[reduction]]
|
|
||||||
self.assertAllClose(result, expected_values)
|
self.assertAllClose(result, expected_values)
|
||||||
|
|
||||||
def test_fit_generator(self, reduction):
|
def test_fit_generator(self, reduction):
|
||||||
@ -685,7 +698,7 @@ class TestOutputLossMetrics(keras_parameterized.TestCase):
|
|||||||
history = model.fit_generator(
|
history = model.fit_generator(
|
||||||
custom_generator_multi_io(
|
custom_generator_multi_io(
|
||||||
sample_weights=[self.sample_weight_1, self.sample_weight_2]),
|
sample_weights=[self.sample_weight_1, self.sample_weight_2]),
|
||||||
steps_per_epoch=2,
|
steps_per_epoch=3,
|
||||||
epochs=2)
|
epochs=2)
|
||||||
for key, value in self.expected_fit_result[reduction].items():
|
for key, value in self.expected_fit_result[reduction].items():
|
||||||
self.assertAllClose(history.history[key], value)
|
self.assertAllClose(history.history[key], value)
|
||||||
@ -696,7 +709,7 @@ class TestOutputLossMetrics(keras_parameterized.TestCase):
|
|||||||
eval_result = model.evaluate_generator(
|
eval_result = model.evaluate_generator(
|
||||||
custom_generator_multi_io(
|
custom_generator_multi_io(
|
||||||
sample_weights=[self.sample_weight_1, self.sample_weight_2]),
|
sample_weights=[self.sample_weight_1, self.sample_weight_2]),
|
||||||
steps=2)
|
steps=3)
|
||||||
self.assertAllClose(eval_result, self.expected_batch_result[reduction])
|
self.assertAllClose(eval_result, self.expected_batch_result[reduction])
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user