Loose the check for BN when momentum == 0.

Fix https://github.com/tensorflow/tensorflow/issues/38459.

PiperOrigin-RevId: 315317303
Change-Id: I814fdcddec94b13296cfabb2fb80e19e7103c234
This commit is contained in:
Scott Zhu 2020-06-08 11:32:43 -07:00 committed by TensorFlower Gardener
parent be20584437
commit 3cfba9571b

View File

@ -1615,16 +1615,11 @@ def fused_batch_norm(
[Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
"""
if is_training and exponential_avg_factor == 1.0:
if (mean is not None) or (variance is not None):
raise ValueError("Both 'mean' and 'variance' must be None when "
"is_training is True and "
"exponential_avg_factor == 1.0.")
else:
if (mean is None) or (variance is None):
raise ValueError("Both 'mean' and 'variance' must be a 1D tensor when "
"is_training is False or "
"exponential_avg_factor != 1.0.")
if (not is_training or exponential_avg_factor != 1.0) and (
(mean is None) or (variance is None)):
raise ValueError("Both 'mean' and 'variance' must be a 1D tensor when "
"is_training is False or "
"exponential_avg_factor != 1.0.")
x = ops.convert_to_tensor(x, name="input")
scale = ops.convert_to_tensor(scale, name="scale")
offset = ops.convert_to_tensor(offset, name="offset")