From 3cfba9571bcc4be237bfdfa3498c66073ae59280 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Mon, 8 Jun 2020 11:32:43 -0700 Subject: [PATCH] Loose the check for BN when momentum == 0. Fix https://github.com/tensorflow/tensorflow/issues/38459. PiperOrigin-RevId: 315317303 Change-Id: I814fdcddec94b13296cfabb2fb80e19e7103c234 --- tensorflow/python/ops/nn_impl.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index eec352b4e2e..cb028bfe1e0 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -1615,16 +1615,11 @@ def fused_batch_norm( [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html) ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf)) """ - if is_training and exponential_avg_factor == 1.0: - if (mean is not None) or (variance is not None): - raise ValueError("Both 'mean' and 'variance' must be None when " - "is_training is True and " - "exponential_avg_factor == 1.0.") - else: - if (mean is None) or (variance is None): - raise ValueError("Both 'mean' and 'variance' must be a 1D tensor when " - "is_training is False or " - "exponential_avg_factor != 1.0.") + if (not is_training or exponential_avg_factor != 1.0) and ( + (mean is None) or (variance is None)): + raise ValueError("Both 'mean' and 'variance' must be a 1D tensor when " + "is_training is False or " + "exponential_avg_factor != 1.0.") x = ops.convert_to_tensor(x, name="input") scale = ops.convert_to_tensor(scale, name="scale") offset = ops.convert_to_tensor(offset, name="offset")