Loose the check for BN when momentum == 0.
Fix https://github.com/tensorflow/tensorflow/issues/38459. PiperOrigin-RevId: 315317303 Change-Id: I814fdcddec94b13296cfabb2fb80e19e7103c234
This commit is contained in:
parent
be20584437
commit
3cfba9571b
@ -1615,16 +1615,11 @@ def fused_batch_norm(
|
|||||||
[Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
|
[Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
|
||||||
([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
|
([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
|
||||||
"""
|
"""
|
||||||
if is_training and exponential_avg_factor == 1.0:
|
if (not is_training or exponential_avg_factor != 1.0) and (
|
||||||
if (mean is not None) or (variance is not None):
|
(mean is None) or (variance is None)):
|
||||||
raise ValueError("Both 'mean' and 'variance' must be None when "
|
raise ValueError("Both 'mean' and 'variance' must be a 1D tensor when "
|
||||||
"is_training is True and "
|
"is_training is False or "
|
||||||
"exponential_avg_factor == 1.0.")
|
"exponential_avg_factor != 1.0.")
|
||||||
else:
|
|
||||||
if (mean is None) or (variance is None):
|
|
||||||
raise ValueError("Both 'mean' and 'variance' must be a 1D tensor when "
|
|
||||||
"is_training is False or "
|
|
||||||
"exponential_avg_factor != 1.0.")
|
|
||||||
x = ops.convert_to_tensor(x, name="input")
|
x = ops.convert_to_tensor(x, name="input")
|
||||||
scale = ops.convert_to_tensor(scale, name="scale")
|
scale = ops.convert_to_tensor(scale, name="scale")
|
||||||
offset = ops.convert_to_tensor(offset, name="offset")
|
offset = ops.convert_to_tensor(offset, name="offset")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user