Merge pull request #35191 from pandrey-fr:fix-addmetrics-dtype
PiperOrigin-RevId: 290781323 Change-Id: Ibda387392f7c1ee7abdc85cdf1bc3e723efafb54
This commit is contained in:
commit
6f3fc0867c
@ -44,7 +44,7 @@ def create_mean_metric(value, name=None):
|
||||
# import keras will import base_layer and then this module, and metric relies
|
||||
# on base_layer, which result into a cyclic dependency.
|
||||
from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top
|
||||
metric_obj = metrics_module.Mean(name=name)
|
||||
metric_obj = metrics_module.Mean(name=name, dtype=value.dtype)
|
||||
return metric_obj, metric_obj(value)
|
||||
|
||||
|
||||
|
@ -1501,7 +1501,8 @@ class Network(base_layer.Layer):
|
||||
new_nodes, new_layers = _map_subgraph_network(self.inputs, [symbolic_loss])
|
||||
# Losses must be keyed on inputs no matter what in order to be supported in
|
||||
# DistributionStrategy.
|
||||
add_loss_layer = base_layer.AddLoss(unconditional=False)
|
||||
add_loss_layer = base_layer.AddLoss(
|
||||
unconditional=False, dtype=symbolic_loss.dtype)
|
||||
add_loss_layer(symbolic_loss)
|
||||
new_nodes.extend(add_loss_layer.inbound_nodes)
|
||||
new_layers.append(add_loss_layer)
|
||||
@ -1509,7 +1510,8 @@ class Network(base_layer.Layer):
|
||||
|
||||
def _graph_network_add_metric(self, value, aggregation, name):
|
||||
new_nodes, new_layers = _map_subgraph_network(self.inputs, [value])
|
||||
add_metric_layer = base_layer.AddMetric(aggregation, name)
|
||||
add_metric_layer = base_layer.AddMetric(
|
||||
aggregation, name, dtype=value.dtype)
|
||||
add_metric_layer(value)
|
||||
new_nodes.extend(add_metric_layer.inbound_nodes)
|
||||
new_layers.append(add_metric_layer)
|
||||
|
Loading…
Reference in New Issue
Block a user