From 31d195921b8bd944c25832e03c181ff2c9cf3f93 Mon Sep 17 00:00:00 2001 From: Paul Andrey Date: Tue, 17 Dec 2019 15:02:32 +0100 Subject: [PATCH] Enabled adding non-default floatX metrics and losses to keras models. --- tensorflow/python/keras/engine/base_layer_utils.py | 2 +- tensorflow/python/keras/engine/network.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index f16f7d16284..60c30c8c23c 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -44,7 +44,7 @@ def create_mean_metric(value, name=None): # import keras will import base_layer and then this module, and metric relies # on base_layer, which result into a cyclic dependency. from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top - metric_obj = metrics_module.Mean(name=name) + metric_obj = metrics_module.Mean(name=name, dtype=value.dtype) return metric_obj, metric_obj(value) diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 7aebdb24e51..1ad1f8ddbb6 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -1477,7 +1477,8 @@ class Network(base_layer.Layer): new_nodes, new_layers = _map_subgraph_network(self.inputs, [symbolic_loss]) # Losses must be keyed on inputs no matter what in order to be supported in # DistributionStrategy. - add_loss_layer = base_layer.AddLoss(unconditional=False) + add_loss_layer = base_layer.AddLoss(unconditional=False, + dtype=symbolic_loss.dtype) add_loss_layer(symbolic_loss) new_nodes.extend(add_loss_layer.inbound_nodes) new_layers.append(add_loss_layer) @@ -1485,7 +1486,8 @@ class Network(base_layer.Layer): def _graph_network_add_metric(self, value, aggregation, name): new_nodes, new_layers = _map_subgraph_network(self.inputs, [value]) - add_metric_layer = base_layer.AddMetric(aggregation, name) + add_metric_layer = base_layer.AddMetric(aggregation, name, + dtype=value.dtype) add_metric_layer(value) new_nodes.extend(add_metric_layer.inbound_nodes) new_layers.append(add_metric_layer)