Merge pull request #35191 from pandrey-fr:fix-addmetrics-dtype

PiperOrigin-RevId: 290781323
Change-Id: Ibda387392f7c1ee7abdc85cdf1bc3e723efafb54
This commit is contained in:
TensorFlower Gardener 2020-01-21 11:14:57 -08:00
commit 6f3fc0867c
2 changed files with 5 additions and 3 deletions

View File

@ -44,7 +44,7 @@ def create_mean_metric(value, name=None):
# import keras will import base_layer and then this module, and metric relies
# on base_layer, which result into a cyclic dependency.
from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top
metric_obj = metrics_module.Mean(name=name)
metric_obj = metrics_module.Mean(name=name, dtype=value.dtype)
return metric_obj, metric_obj(value)

View File

@ -1501,7 +1501,8 @@ class Network(base_layer.Layer):
new_nodes, new_layers = _map_subgraph_network(self.inputs, [symbolic_loss])
# Losses must be keyed on inputs no matter what in order to be supported in
# DistributionStrategy.
add_loss_layer = base_layer.AddLoss(unconditional=False)
add_loss_layer = base_layer.AddLoss(
unconditional=False, dtype=symbolic_loss.dtype)
add_loss_layer(symbolic_loss)
new_nodes.extend(add_loss_layer.inbound_nodes)
new_layers.append(add_loss_layer)
@ -1509,7 +1510,8 @@ class Network(base_layer.Layer):
def _graph_network_add_metric(self, value, aggregation, name):
new_nodes, new_layers = _map_subgraph_network(self.inputs, [value])
add_metric_layer = base_layer.AddMetric(aggregation, name)
add_metric_layer = base_layer.AddMetric(
aggregation, name, dtype=value.dtype)
add_metric_layer(value)
new_nodes.extend(add_metric_layer.inbound_nodes)
new_layers.append(add_metric_layer)