Fix ema type mismatch for special ema value.

PiperOrigin-RevId: 348118379
Change-Id: Ife5f789e42e30b08384c354a50c3fec3cdcefd10
This commit is contained in:
Mingxing Tan 2020-12-17 16:52:49 -08:00 committed by TensorFlower Gardener
parent 347146c5f6
commit 8e6e0c1ac9

View File

@ -473,7 +473,8 @@ class ExponentialMovingAverage(object):
self._averages[var.ref()] = avg
with ops.name_scope(self.name) as scope:
decay = ops.convert_to_tensor(self._decay, name="decay")
decay = ops.convert_to_tensor(
self._decay, dtype=dtypes.float32, name="decay")
if self._num_updates is not None:
num_updates = math_ops.cast(
self._num_updates, dtypes.float32, name="num_updates")