Enabled adding non-default floatX metrics and losses to keras models.
This commit is contained in:
parent
1768c8f2fa
commit
31d195921b
tensorflow/python/keras/engine
@ -44,7 +44,7 @@ def create_mean_metric(value, name=None):
|
||||
# import keras will import base_layer and then this module, and metric relies
|
||||
# on base_layer, which result into a cyclic dependency.
|
||||
from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top
|
||||
metric_obj = metrics_module.Mean(name=name)
|
||||
metric_obj = metrics_module.Mean(name=name, dtype=value.dtype)
|
||||
return metric_obj, metric_obj(value)
|
||||
|
||||
|
||||
|
@ -1477,7 +1477,8 @@ class Network(base_layer.Layer):
|
||||
new_nodes, new_layers = _map_subgraph_network(self.inputs, [symbolic_loss])
|
||||
# Losses must be keyed on inputs no matter what in order to be supported in
|
||||
# DistributionStrategy.
|
||||
add_loss_layer = base_layer.AddLoss(unconditional=False)
|
||||
add_loss_layer = base_layer.AddLoss(unconditional=False,
|
||||
dtype=symbolic_loss.dtype)
|
||||
add_loss_layer(symbolic_loss)
|
||||
new_nodes.extend(add_loss_layer.inbound_nodes)
|
||||
new_layers.append(add_loss_layer)
|
||||
@ -1485,7 +1486,8 @@ class Network(base_layer.Layer):
|
||||
|
||||
def _graph_network_add_metric(self, value, aggregation, name):
|
||||
new_nodes, new_layers = _map_subgraph_network(self.inputs, [value])
|
||||
add_metric_layer = base_layer.AddMetric(aggregation, name)
|
||||
add_metric_layer = base_layer.AddMetric(aggregation, name,
|
||||
dtype=value.dtype)
|
||||
add_metric_layer(value)
|
||||
new_nodes.extend(add_metric_layer.inbound_nodes)
|
||||
new_layers.append(add_metric_layer)
|
||||
|
Loading…
Reference in New Issue
Block a user