From 8e6e0c1ac917c691de69318bbada08970adced22 Mon Sep 17 00:00:00 2001 From: Mingxing Tan Date: Thu, 17 Dec 2020 16:52:49 -0800 Subject: [PATCH] Fix ema type mismatch for special ema value. PiperOrigin-RevId: 348118379 Change-Id: Ife5f789e42e30b08384c354a50c3fec3cdcefd10 --- tensorflow/python/training/moving_averages.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index b95e366aa38..768188f1cf7 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -473,7 +473,8 @@ class ExponentialMovingAverage(object): self._averages[var.ref()] = avg with ops.name_scope(self.name) as scope: - decay = ops.convert_to_tensor(self._decay, name="decay") + decay = ops.convert_to_tensor( + self._decay, dtype=dtypes.float32, name="decay") if self._num_updates is not None: num_updates = math_ops.cast( self._num_updates, dtypes.float32, name="num_updates")