Enabled adding non-default floatX metrics and losses to keras models.

This commit is contained in:
Paul Andrey 2019-12-17 15:02:32 +01:00
parent 1768c8f2fa
commit 31d195921b
2 changed files with 5 additions and 3 deletions
tensorflow/python/keras/engine

View File

@ -44,7 +44,7 @@ def create_mean_metric(value, name=None):
# import keras will import base_layer and then this module, and metric relies
# on base_layer, which result into a cyclic dependency.
from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top
metric_obj = metrics_module.Mean(name=name)
metric_obj = metrics_module.Mean(name=name, dtype=value.dtype)
return metric_obj, metric_obj(value)

View File

@ -1477,7 +1477,8 @@ class Network(base_layer.Layer):
new_nodes, new_layers = _map_subgraph_network(self.inputs, [symbolic_loss])
# Losses must be keyed on inputs no matter what in order to be supported in
# DistributionStrategy.
add_loss_layer = base_layer.AddLoss(unconditional=False)
add_loss_layer = base_layer.AddLoss(unconditional=False,
dtype=symbolic_loss.dtype)
add_loss_layer(symbolic_loss)
new_nodes.extend(add_loss_layer.inbound_nodes)
new_layers.append(add_loss_layer)
@ -1485,7 +1486,8 @@ class Network(base_layer.Layer):
def _graph_network_add_metric(self, value, aggregation, name):
new_nodes, new_layers = _map_subgraph_network(self.inputs, [value])
add_metric_layer = base_layer.AddMetric(aggregation, name)
add_metric_layer = base_layer.AddMetric(aggregation, name,
dtype=value.dtype)
add_metric_layer(value)
new_nodes.extend(add_metric_layer.inbound_nodes)
new_layers.append(add_metric_layer)