From 84f2ec1d60b5bb14a59ccef8f8fa7eb5a1096e8f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 26 Feb 2020 14:27:38 -0800 Subject: [PATCH] Support computing exponential running mean and variance in fused_batch_norm. Simplify unit test code for fused_batch_norm and its gradient. PiperOrigin-RevId: 297451723 Change-Id: I3ad1db022848cec3cf45a5a560a8f25af75781d4 --- .../compiler/tests/fused_batchnorm_test.py | 89 ++-- .../compiler/tf2xla/kernels/batch_norm_op.cc | 28 +- tensorflow/core/framework/BUILD | 1 + tensorflow/core/framework/common_shape_fns.cc | 24 +- .../core/framework/common_shape_fns_test.cc | 4 +- tensorflow/core/kernels/BUILD | 2 +- .../core/kernels/fused_batch_norm_op.cc | 33 +- tensorflow/core/ops/nn_ops_test.cc | 58 ++- tensorflow/python/BUILD | 2 +- .../python/ops/nn_fused_batchnorm_test.py | 399 +++++++----------- tensorflow/python/ops/nn_grad.py | 5 +- tensorflow/python/ops/nn_impl.py | 96 +++-- tensorflow/stream_executor/cuda/cuda_dnn.cc | 6 +- .../tools/api/golden/v1/tensorflow.nn.pbtxt | 2 +- 14 files changed, 388 insertions(+), 361 deletions(-) diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index ad8368a2bfb..6a9076e9be8 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.compiler.tests import test_utils from tensorflow.compiler.tests import xla_test +from tensorflow.python.compat import compat from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gradient_checker @@ -34,10 +35,18 @@ DATA_FORMATS = ( ("_data_format_NCHW", "NCHW"), ) +DATA_FORMATS_AND_AVG_FACTORS = ( + ("_data_format_NHWC_no_averaging", "NHWC", 1.0), + ("_data_format_NHWC_averaging", "NHWC", 0.6), + ("_data_format_NCHW_no_averaging", "NCHW", 1.0), + ("_data_format_NCHW_averaging", "NCHW", 0.6), +) + class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): - def _reference_training(self, x, scale, offset, epsilon, data_format): + def _reference_training(self, x, scale, offset, old_mean, old_var, epsilon, + exponential_avg_factor, data_format): if data_format != "NHWC": raise ValueError("data_format must be NHWC, got %s." 
% data_format) x_square = x * x @@ -49,6 +58,11 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): factor = element_count / max(element_count - 1, 1) corrected_var = var * factor normalized = (x - mean) / np.sqrt(var + epsilon) + if exponential_avg_factor != 1.0: + mean = (1.0 - + exponential_avg_factor) * old_mean + exponential_avg_factor * mean + corrected_var = (1.0 - exponential_avg_factor + ) * old_var + exponential_avg_factor * corrected_var return (normalized * scale + offset), mean, var, corrected_var def _reference_grad(self, x, grad_y, scale, mean, var, epsilon, data_format): @@ -81,9 +95,11 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): scale_val = np.random.random_sample(scale_shape).astype(np.float32) offset_val = np.random.random_sample(scale_shape).astype(np.float32) epsilon = 0.001 + exponential_avg_factor = 1.0 data_format_src = "NHWC" y_ref, mean_ref, var_ref, _ = self._reference_training( - x_val, scale_val, offset_val, epsilon, data_format_src) + x_val, scale_val, offset_val, None, None, epsilon, + exponential_avg_factor, data_format_src) with self.session() as sess, self.test_scope(): # To avoid constant folding @@ -114,7 +130,11 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): }) self.assertAllClose(y_val, y_ref_converted, atol=1e-3) - def _testLearning(self, use_gradient_checker, data_format): + def _testLearning(self, use_gradient_checker, data_format, + exponential_avg_factor): + if not compat.forward_compatible(2020, 3, + 6) and exponential_avg_factor != 1.0: + self.skipTest("running average not available.") channel = 3 x_shape = [2, 2, 6, channel] scale_shape = [channel] @@ -122,13 +142,14 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): scale_val = np.random.random_sample(scale_shape).astype(np.float32) offset_val = np.random.random_sample(scale_shape).astype(np.float32) mean_val = np.random.random_sample(scale_shape).astype(np.float32) - var_val = np.random.random_sample(scale_shape).astype(np.float32) + var_val_corr = np.random.random_sample(scale_shape).astype(np.float32) epsilon = 0.001 data_format_src = "NHWC" # When in training mode, fused_batchnorm applies an implicit Bessel's # correction. So we have to use the corrected variance here, as well. 
y_ref, mean_ref, _, var_ref_corr = self._reference_training( - x_val, scale_val, offset_val, epsilon, data_format_src) + x_val, scale_val, offset_val, mean_val, var_val_corr, epsilon, + exponential_avg_factor, data_format_src) with self.session() as sess, self.test_scope(): # To avoid constant folding @@ -142,15 +163,38 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") offset = array_ops.placeholder( np.float32, shape=scale_shape, name="offset") + if exponential_avg_factor == 1.0: + old_mean = None + old_var = None + else: + old_mean = array_ops.placeholder( + np.float32, shape=scale_shape, name="old_mean") + old_var = array_ops.placeholder( + np.float32, shape=scale_shape, name="old_var") y, mean, var = nn.fused_batch_norm( t_val, scale, offset, - mean=None, - variance=None, + mean=old_mean, + variance=old_var, epsilon=epsilon, + exponential_avg_factor=exponential_avg_factor, data_format=data_format, is_training=True) + if exponential_avg_factor == 1.0: + feed_dict = { + t_val: x_val_converted, + scale: scale_val, + offset: offset_val, + } + else: + feed_dict = { + t_val: x_val_converted, + scale: scale_val, + offset: offset_val, + old_mean: mean_val, + old_var: var_val_corr + } # Check gradient. if use_gradient_checker: err = gradient_checker.compute_gradient_error( @@ -158,29 +202,22 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): x_val_converted.shape, y, x_val_converted.shape, - extra_feed_dict={ - t_val: x_val_converted, - scale: scale_val, - offset: offset_val - }) + extra_feed_dict=feed_dict) self.assertLess(err, 1e-3) - y_val, mean_val, var_val = sess.run([y, mean, var], { - t_val: x_val_converted, - scale: scale_val, - offset: offset_val - }) - self.assertAllClose(mean_val, mean_ref, atol=1e-3) - self.assertAllClose(y_val, y_ref_converted, atol=1e-3) - self.assertAllClose(var_val, var_ref_corr, atol=1e-3) + y_tf, mean_tf, var_tf = sess.run([y, mean, var], feed_dict) + self.assertAllClose(y_tf, y_ref_converted, atol=1e-3) + self.assertAllClose(mean_tf, mean_ref, atol=1e-3) + self.assertAllClose(var_tf, var_ref_corr, atol=1e-3) - @parameterized.named_parameters(*DATA_FORMATS) - def testLearning(self, data_format): - self._testLearning(False, data_format) + @parameterized.named_parameters(*DATA_FORMATS_AND_AVG_FACTORS) + def testLearning(self, data_format, exponential_avg_factor): + self._testLearning(False, data_format, exponential_avg_factor) - @parameterized.named_parameters(*DATA_FORMATS) - def testLearningWithGradientChecker(self, data_format): - self._testLearning(True, data_format) + @parameterized.named_parameters(*DATA_FORMATS_AND_AVG_FACTORS) + def testLearningWithGradientChecker(self, data_format, + exponential_avg_factor): + self._testLearning(True, data_format, exponential_avg_factor) @parameterized.named_parameters(*DATA_FORMATS) def testGradientTraining(self, data_format): diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index 37f2e868a3d..fcc93eb0e8d 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -37,6 +37,8 @@ class FusedBatchNormOp : public XlaOpKernel { : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_)); + OP_REQUIRES_OK( + ctx, ctx->GetAttr("exponential_avg_factor", &exponential_avg_factor_)); string 
data_format_str; OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); OP_REQUIRES( @@ -105,7 +107,6 @@ class FusedBatchNormOp : public XlaOpKernel { ctx->SetOutput(0, converted); } - ctx->SetOutput(1, xla::GetTupleElement(output, 1)); xla::XlaOp variance = xla::GetTupleElement(output, 2); // Apply Bessel's correction. int total_input_size = ctx->InputShape(0).num_elements(); @@ -121,7 +122,7 @@ class FusedBatchNormOp : public XlaOpKernel { if (input_shape.num_elements() == 0) { auto status_or_output_shape = b->GetShape(corrected); OP_REQUIRES_OK(ctx, status_or_output_shape.status()); - + ctx->SetOutput(1, xla::GetTupleElement(output, 1)); ctx->SetOutput( kVarianceOutputIndex, xla::Broadcast( @@ -130,7 +131,27 @@ class FusedBatchNormOp : public XlaOpKernel { status_or_output_shape.ValueOrDie().dimensions()))); } else { - ctx->SetOutput(2, corrected); + if (exponential_avg_factor_ == 1.0f) { + ctx->SetOutput(1, xla::GetTupleElement(output, 1)); + ctx->SetOutput(2, corrected); + } else { + xla::XlaOp old_mean = ctx->Input(3); + xla::XlaOp alpha = + xla::ScalarLike(old_mean, 1.0f - exponential_avg_factor_); + xla::XlaOp beta = xla::ScalarLike(old_mean, exponential_avg_factor_); + // new_running_mean = alpha * old_mean + beta * batch_mean. + xla::XlaOp new_running_mean = + xla::Add(xla::Mul(old_mean, alpha), + xla::Mul(xla::GetTupleElement(output, 1), beta)); + ctx->SetOutput(1, new_running_mean); + + xla::XlaOp old_variance = ctx->Input(4); + xla::XlaOp new_running_variance = xla::Add( + xla::Mul(old_variance, alpha), xla::Mul(corrected, beta)); + // new_running_variance = alpha * old_variance + beta * + // batch_variance. + ctx->SetOutput(2, new_running_variance); + } } // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved @@ -175,6 +196,7 @@ class FusedBatchNormOp : public XlaOpKernel { float epsilon_; TensorFormat data_format_; bool is_training_; + float exponential_avg_factor_; bool add_side_input_; bool apply_relu_; bool is_on_gpu_; diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index bf7d8f1e2fb..5cf57e82357 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -787,6 +787,7 @@ cc_library( "//tensorflow/core/util:padding", "//tensorflow/core/util:tensor_format", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 06831d5f516..2d39be1379e 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include <unordered_set> +#include "tensorflow/core/framework/common_shape_fns.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/match.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -1083,7 +1083,11 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) { bool is_training; TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training)); - int number_inputs = (is_training) ? 3 : 5; + float exponential_avg_factor; + if (!c->GetAttr("exponential_avg_factor", &exponential_avg_factor).ok()) { + exponential_avg_factor = 1.0f; // default value + } + int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5; string data_format_str; TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); TensorFormat data_format; @@ -1176,16 +1180,8 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) { c->set_output(0, x_backprop); c->set_output(1, c->Vector(channel_dim)); c->set_output(2, c->Vector(channel_dim)); - // Set the correct shapes for reserve_spaces - // so that gradients can be performed when - // the op is in a symbolic condition. - if (is_training) { - c->set_output(3, c->Vector(0)); - c->set_output(4, c->Vector(0)); - } else { - c->set_output(3, c->Vector(channel_dim)); - c->set_output(4, c->Vector(channel_dim)); - } + c->set_output(3, c->Vector(0)); + c->set_output(4, c->Vector(0)); return Status::OK(); } @@ -2326,7 +2322,7 @@ Status SparseReduceShapeFn(InferenceContext* c) { auto axes_vec = axes_tensor->flat<int32>(); int64 ndims = shape_vec.size(); - std::unordered_set<int64> axes; + absl::flat_hash_set<int64> axes; for (int i = 0; i < axes_vec.size(); i++) { axes.insert((axes_vec(i) + ndims) % ndims); } diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc index 68c448c8007..f2755c8917e 100644 --- a/tensorflow/core/framework/common_shape_fns_test.cc +++ b/tensorflow/core/framework/common_shape_fns_test.cc @@ -528,9 +528,9 @@ TEST(CommonShapeFnsTest, FusedBatchNormExTest) { .Finalize(&op.node_def)); // Channels are not multiple of 4.
- INFER_ERROR("must be divisible by 4", op, "[2,2,2,2];[2];[2];[0];[0]"); + INFER_ERROR("must be divisible by 4", op, "[2,2,2,2];[2];[2];[2];[2]"); - INFER_OK(op, "[2,2,2,4];[4];[4];[0];[0]", + INFER_OK(op, "[2,2,2,4];[4];[4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3];[d0_3];[d0_3];[d0_3];[d0_3];?"); } diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index efa11dedba4..a8e2268b732 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1923,7 +1923,7 @@ tf_cc_test( ], ) -tf_cc_test( +tf_cuda_cc_test( name = "fused_batch_norm_op_test", size = "small", srcs = ["fused_batch_norm_op_test.cc"], diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc index afe3e621fcf..00ac9be6dcd 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/fused_batch_norm_op.cc @@ -929,6 +929,28 @@ struct FusedBatchNorm { workspace_allocator.reset( new functor::CudnnBatchNormAllocatorInTemp(context)); } + if (!batch_mean->SharesBufferWith(estimated_mean) && + exponential_avg_factor != 1.0f) { + OP_REQUIRES( + context, + stream + ->ThenMemcpyD2D(&batch_mean_ptr, estimated_mean_ptr, + estimated_mean.NumElements() * sizeof(U)) + .ok(), + errors::Internal("MatrixTriangularSolveOp: failed to copy rhs " + "from device")); + } + if (!batch_var->SharesBufferWith(estimated_variance) && + exponential_avg_factor != 1.0f) { + OP_REQUIRES( + context, + stream + ->ThenMemcpyD2D(&batch_var_ptr, estimated_variance_ptr, + estimated_variance.NumElements() * sizeof(U)) + .ok(), + errors::Internal("MatrixTriangularSolveOp: failed to copy rhs " + "from device")); + } bool cudnn_launch_status = stream ->ThenBatchNormalizationForward( @@ -1263,11 +1285,11 @@ class FusedBatchNormOpBase : public OpKernel { OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( {0}, 0, x.shape(), &y)); Tensor* batch_mean = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(1, scale.shape(), &batch_mean)); + OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( + {3}, 1, scale.shape(), &batch_mean)); Tensor* batch_var = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(2, scale.shape(), &batch_var)); + OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( + {4}, 2, scale.shape(), &batch_var)); Tensor* saved_mean = nullptr; OP_REQUIRES_OK(context, context->allocate_output(3, scale.shape(), &saved_mean)); @@ -1400,12 +1422,9 @@ class FusedBatchNormGradOpBase : public OpKernel { Tensor* placeholder_1 = nullptr; OP_REQUIRES_OK( context, context->allocate_output(3, TensorShape({0}), &placeholder_1)); - functor::SetZeroFunctor f; - f(context->eigen_device(), placeholder_1->flat()); Tensor* placeholder_2 = nullptr; OP_REQUIRES_OK( context, context->allocate_output(4, TensorShape({0}), &placeholder_2)); - f(context->eigen_device(), placeholder_2->flat()); // If input is empty, set gradients w.r.t scale/offset to zero. 
if (x.shape().num_elements() == 0) { diff --git a/tensorflow/core/ops/nn_ops_test.cc b/tensorflow/core/ops/nn_ops_test.cc index 289b9530556..b53f7624d96 100644 --- a/tensorflow/core/ops/nn_ops_test.cc +++ b/tensorflow/core/ops/nn_ops_test.cc @@ -181,7 +181,8 @@ TEST(NNOpsTest, BatchNormWithGlobalNormalizationGrad_ShapeFn) { TEST(NNOpsTest, FusedBatchNorm_ShapeFn) { ShapeInferenceTestOp op("FusedBatchNorm"); - auto set_op = [&op](bool is_training, string data_format) { + auto set_op = [&op](bool is_training, float exponential_avg_factor, + string data_format) { TF_ASSERT_OK(NodeDefBuilder("test", "FusedBatchNorm") .Input(FakeInput(DT_FLOAT)) .Input(FakeInput(DT_FLOAT)) @@ -190,10 +191,11 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) { .Input(FakeInput(DT_FLOAT)) .Attr("data_format", data_format) .Attr("is_training", is_training) + .Attr("exponential_avg_factor", exponential_avg_factor) .Finalize(&op.node_def)); }; - set_op(true, "NHWC"); + set_op(true, 1.0, "NHWC"); // Test rank errors. INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?"); INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?"); @@ -207,7 +209,21 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) { "[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0];" "[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0]"); - set_op(true, "NCHW"); + set_op(true, 0.5, "NHWC"); + // Test rank errors. + INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?"); + INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?"); + INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?"); + // Channel dim of first input is merged with the single dim in other 4 inputs. + INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]"); + INFER_OK(op, "?;[1];?;?;?", "[?,?,?,d1_0];[d1_0];[d1_0];[d1_0];[d1_0]"); + INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[d2_0];[d2_0]"); + INFER_OK(op, "[1,2,3,4];[4];[4];?;?", + "[d0_0,d0_1,d0_2,d0_3|d1_0|d2_0];" + "[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0];" + "[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0]"); + + set_op(true, 1.0, "NCHW"); // Test rank errors. INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?"); INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?"); @@ -221,7 +237,7 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) { "[d0_1|d1_0|d2_0];[d0_1|d1_0|d2_0];" "[d0_1|d1_0|d2_0];[d0_1|d1_0|d2_0]"); - set_op(false, "NHWC"); + set_op(false, 1.0, "NHWC"); // Test rank errors. INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?"); INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?"); @@ -239,7 +255,7 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) { "[d0_3|d1_0|d2_0|d3_0|d4_0];[d0_3|d1_0|d2_0|d3_0|d4_0];" "[d0_3|d1_0|d2_0|d3_0|d4_0];[d0_3|d1_0|d2_0|d3_0|d4_0]"); - set_op(false, "NCHW"); + set_op(false, 1.0, "NCHW"); // Test rank errors. INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?"); INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?"); @@ -271,22 +287,6 @@ TEST(NNOpsTest, FusedBatchNormGrad_ShapeFn) { .Finalize(&op.node_def)); }; - set_op("NHWC"); - // Test rank errors. 
- INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?"); - INFER_ERROR("Shape must be rank 4 but is rank 3", op, "?;[1,2,3];?;?;?"); - INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?"); - INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?"); - INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]"); - // Channel dim of first input is merged with the single dim in other 4 inputs. - INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[0];[0]"); - INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[0];[0]"); - INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0];[d3_0];[d3_0];[0];[0]"); - INFER_OK(op, "?;?;?;?;[1]", "[?,?,?,d4_0];[d4_0];[d4_0];[0];[0]"); - INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4];[4]", - "[d0_0,d0_1,d0_2,d0_3|d2_0|d3_0|d4_0];" - "[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];[0];[0]"); - set_op("NCHW"); // Test rank errors. INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?"); @@ -302,6 +302,22 @@ TEST(NNOpsTest, FusedBatchNormGrad_ShapeFn) { INFER_OK(op, "[1,4,2,3];[1,4,2,3];[4];[4];[4]", "[d0_0,d0_1|d2_0|d3_0|d4_0,d0_2,d0_3];" "[d0_1|d2_0|d3_0|d4_0];[d0_1|d2_0|d3_0|d4_0];[0];[0]"); + + set_op("NHWC"); + // Test rank errors. + INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?"); + INFER_ERROR("Shape must be rank 4 but is rank 3", op, "?;[1,2,3];?;?;?"); + INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?"); + INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?"); + INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]"); + // Channel dim of first input is merged with the single dim in other 4 inputs. + INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[0];[0]"); + INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[0];[0]"); + INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0];[d3_0];[d3_0];[0];[0]"); + INFER_OK(op, "?;?;?;?;[1]", "[?,?,?,d4_0];[d4_0];[d4_0];[0];[0]"); + INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4];[4]", + "[d0_0,d0_1,d0_2,d0_3|d2_0|d3_0|d4_0];" + "[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];[0];[0]"); } TEST(NNOpsTest, Conv3DBackpropInput_ShapeFn) { diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 6f7ec6389a3..1f11b892776 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -5015,7 +5015,7 @@ cuda_py_test( size = "large", srcs = ["ops/nn_fused_batchnorm_test.py"], python_version = "PY3", - shard_count = 16, + shard_count = 24, deps = [ ":array_ops", ":client_testlib", diff --git a/tensorflow/python/ops/nn_fused_batchnorm_test.py b/tensorflow/python/ops/nn_fused_batchnorm_test.py index 130034fbeec..5236d9049ee 100644 --- a/tensorflow/python/ops/nn_fused_batchnorm_test.py +++ b/tensorflow/python/ops/nn_fused_batchnorm_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.compat import compat from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util @@ -59,6 +60,7 @@ class BatchNormalizationTest(test.TestCase): scale_shape, scale_dtype, use_gpu=True, + exponential_avg_factor=1.0, data_format='NHWC'): np.random.seed(1) x_val = np.random.random_sample(x_shape).astype(x_dtype) @@ -81,6 +83,7 @@ class BatchNormalizationTest(test.TestCase): mean=mean, variance=var, epsilon=epsilon, + exponential_avg_factor=exponential_avg_factor, data_format=data_format, is_training=False) y_val = self.evaluate(y) @@ -92,17 +95,37 @@ class 
BatchNormalizationTest(test.TestCase): atol = 2e-3 if x_dtype == np.float16 else 1e-3 self.assertAllClose(y_ref, y_val, atol=atol) - def _training_ref(self, x, scale, offset, epsilon, data_format): + def _running_mean(self, old_mean, new_val, factor): + if factor == 1.0: + return new_val + else: + return (1.0 - factor) * old_mean + factor * new_val + + def _training_ref(self, x, scale, offset, old_mean, old_var, + exponential_avg_factor, epsilon, data_format): if data_format not in ['NHWC', 'NCHW']: raise ValueError('data_format must be NCHW or NHWC, ' 'got %s.' % data_format) if data_format == 'NCHW': x = array_ops.transpose(x, [0, 2, 3, 1]) - mean, var = nn_impl.moments( + batch_mean, batch_var = nn_impl.moments( math_ops.cast(x, scale.dtype), [0, 1, 2], keep_dims=False) - y = self._batch_norm(x, mean, var, offset, scale, epsilon) + + y = self._batch_norm(x, batch_mean, batch_var, offset, scale, epsilon) if data_format == 'NCHW': y = array_ops.transpose(y, [0, 3, 1, 2]) + + # This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as + # the denominator in the formula to calculate variance, while + # tf.compat.v1.nn.fused_batch_norm has Bessel's correction built in. + sample_size = math_ops.cast( + array_ops.size(x) / array_ops.size(scale), scale.dtype) + batch_var_corrected = batch_var * sample_size / ( + math_ops.maximum(sample_size - 1.0, 1.0)) + + mean = self._running_mean(old_mean, batch_mean, exponential_avg_factor) + var = self._running_mean(old_var, batch_var_corrected, + exponential_avg_factor) return self.evaluate(y), self.evaluate(mean), self.evaluate(var) def _test_training(self, @@ -111,11 +134,19 @@ class BatchNormalizationTest(test.TestCase): scale_shape, scale_dtype, use_gpu=True, + exponential_avg_factor=1.0, data_format='NHWC'): np.random.seed(1) x_val = np.random.random_sample(x_shape).astype(x_dtype) scale_val = np.random.random_sample(scale_shape).astype(scale_dtype) offset_val = np.random.random_sample(scale_shape).astype(scale_dtype) + if exponential_avg_factor == 1.0: + old_mean_val = None + old_var_val = None + else: + old_mean_val = np.random.random_sample(scale_shape).astype(scale_dtype) + old_var_val = np.random.random_sample(scale_shape).astype(scale_dtype) + with self.cached_session(use_gpu=use_gpu) as sess: x = constant_op.constant(x_val, name='x') scale = constant_op.constant(scale_val, name='scale') @@ -125,20 +156,20 @@ class BatchNormalizationTest(test.TestCase): x, scale, offset, + mean=old_mean_val, + variance=old_var_val, epsilon=epsilon, + exponential_avg_factor=exponential_avg_factor, data_format=data_format, is_training=True) y_val, mean_val, var_val = self.evaluate([y, mean, var]) - y_ref, mean_ref, var_ref = self._training_ref(x, scale, offset, epsilon, - data_format) + y_ref, mean_ref, var_ref = self._training_ref(x, scale, offset, + old_mean_val, old_var_val, + exponential_avg_factor, + epsilon, data_format) y_atol = 2e-3 if x_dtype == np.float16 else 1e-3 self.assertAllClose(y_ref, y_val, atol=y_atol) self.assertAllClose(mean_ref, mean_val, atol=1e-3) - # This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as - # the denominator in the formula to calculate variance, while - # tf.compat.v1.nn.fused_batch_norm has Bessel's correction built in. 
- sample_size = x_val.size / scale_val.size - var_ref = var_ref * sample_size / (max(sample_size - 1.0, 1.0)) self.assertAllClose(var_ref, var_val, atol=1e-3) def _compute_gradient_error_float16(self, x, x32, x_shape, y, y32, y_shape): @@ -184,6 +215,7 @@ class BatchNormalizationTest(test.TestCase): scale_shape, scale_dtype, use_gpu=True, + exponential_avg_factor=1.0, data_format='NHWC', is_training=True): np.random.seed(1) @@ -195,7 +227,7 @@ class BatchNormalizationTest(test.TestCase): x = constant_op.constant(x_val, name='x') scale = constant_op.constant(scale_val, name='scale') offset = constant_op.constant(offset_val, name='offset') - if is_training: + if is_training and exponential_avg_factor == 1.0: pop_mean = None pop_var = None else: @@ -207,6 +239,7 @@ class BatchNormalizationTest(test.TestCase): offset, mean=pop_mean, variance=pop_var, + exponential_avg_factor=exponential_avg_factor, data_format=data_format, is_training=is_training) if x_dtype != np.float16: @@ -224,6 +257,7 @@ class BatchNormalizationTest(test.TestCase): mean=pop_mean, variance=pop_var, data_format=data_format, + exponential_avg_factor=exponential_avg_factor, is_training=is_training) err_x = self._compute_gradient_error_float16(x, x32, x_shape, y, y32, x_shape) @@ -244,6 +278,7 @@ class BatchNormalizationTest(test.TestCase): scale_shape, scale_dtype, use_gpu=True, + exponential_avg_factor=1.0, data_format='NHWC', is_training=True, err_tolerance=1e-3): @@ -258,7 +293,7 @@ class BatchNormalizationTest(test.TestCase): grad_y = constant_op.constant(grad_y_val, name='grad_y') scale = constant_op.constant(scale_val, name='scale') offset = constant_op.constant(offset_val, name='offset') - if is_training: + if is_training and exponential_avg_factor == 1.0: pop_mean = None pop_var = None else: @@ -270,6 +305,7 @@ class BatchNormalizationTest(test.TestCase): offset, mean=pop_mean, variance=pop_var, + exponential_avg_factor=exponential_avg_factor, data_format=data_format, is_training=is_training) grad_x, grad_scale, grad_offset = gradients_impl.gradients( @@ -311,6 +347,7 @@ class BatchNormalizationTest(test.TestCase): offset, mean=pop_mean, variance=pop_var, + exponential_avg_factor=exponential_avg_factor, data_format=data_format, is_training=is_training) grad_x32, grad_scale32, grad_offset32 = gradients_impl.gradients( @@ -339,282 +376,146 @@ class BatchNormalizationTest(test.TestCase): self.assertLess(err_grad_x_2, err_tolerance) self.assertLess(err_grad_scale, err_tolerance) + def _runtests(self, x_shape, is_training, gradient_test=False): + use_gpu_vals = [False] + if test.is_gpu_available(cuda_only=True): + use_gpu_vals += [True] + factors = [ + 1.0, + ] + if compat.forward_compatible(2020, 3, 6): + factors += [ + 0.6, + ] + for dtype in [np.float16, np.float32]: + for use_gpu in use_gpu_vals: + for data_format in ['NHWC', 'NCHW']: + if data_format == 'NHWC': + scale_shape = x_shape[-1:] + else: + scale_shape = x_shape[1:2] + for exponential_avg_factor in factors: + if gradient_test: + self._test_gradient( + x_shape, + dtype, + scale_shape, + np.float32, + use_gpu=use_gpu, + data_format=data_format, + is_training=is_training, + exponential_avg_factor=exponential_avg_factor) + else: + if is_training: + self._test_training( + x_shape, + dtype, + scale_shape, + np.float32, + use_gpu=use_gpu, + data_format=data_format, + exponential_avg_factor=exponential_avg_factor) + else: + self._test_inference( + x_shape, + dtype, + scale_shape, + np.float32, + use_gpu=use_gpu, + data_format=data_format, + 
exponential_avg_factor=exponential_avg_factor) + def testInferenceShape1(self): x_shape = [1, 1, 6, 1] - for dtype in [np.float16, np.float32]: - if test.is_gpu_available(cuda_only=True): - self._test_inference( - x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NHWC') - self._test_inference( - x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NCHW') - self._test_inference( - x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NHWC') - self._test_inference( - x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NCHW') + self._runtests(x_shape, False) def testInferenceShape2(self): x_shape = [1, 1, 6, 2] - if test.is_gpu_available(cuda_only=True): - for dtype in [np.float16, np.float32]: - self._test_inference( - x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NHWC') - self._test_inference( - x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NHWC') + self._runtests(x_shape, False) def testInferenceShape3(self): x_shape = [1, 2, 1, 6] - if test.is_gpu_available(cuda_only=True): - for dtype in [np.float16, np.float32]: - self._test_inference( - x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NCHW') + self._runtests(x_shape, False) def testInferenceShape4(self): x_shape = [27, 131, 127, 6] - for dtype in [np.float16, np.float32]: - if test.is_gpu_available(cuda_only=True): - self._test_inference( - x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW') - self._test_inference( - x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC') - self._test_inference( - x_shape, dtype, [131], np.float32, use_gpu=False, data_format='NCHW') - self._test_inference( - x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC') + self._runtests(x_shape, False) def testInferenceShape5(self): x_shape = [0, 131, 127, 6] - for dtype in [np.float16, np.float32]: - if test.is_gpu_available(cuda_only=True): - self._test_inference( - x_shape, - dtype, [131], - np.float32, - use_gpu=True, - data_format='NCHW') - self._test_inference( - x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC') - self._test_inference( - x_shape, - dtype, [131], - np.float32, - use_gpu=False, - data_format='NCHW') - self._test_inference( - x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC') + self._runtests(x_shape, False) def testTrainingShape1(self): x_shape = [1, 1, 6, 1] - for dtype in [np.float16, np.float32]: - if test.is_gpu_available(cuda_only=True): - self._test_training( - x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NHWC') - self._test_training( - x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NCHW') - self._test_training( - x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NHWC') - self._test_training( - x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NCHW') + self._runtests(x_shape, True) def testTrainingShape2(self): x_shape = [1, 1, 6, 2] - for dtype in [np.float16, np.float32]: - if test.is_gpu_available(cuda_only=True): - self._test_training( - x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NHWC') - self._test_training( - x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NHWC') + self._runtests(x_shape, True) def testTrainingShape3(self): x_shape = [1, 2, 1, 6] - for dtype in [np.float16, np.float32]: - if test.is_gpu_available(cuda_only=True): - self._test_training( - x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NCHW') - self._test_training( - x_shape, dtype, [2], np.float32, use_gpu=False, 
data_format='NCHW') + self._runtests(x_shape, True) def testTrainingShape4(self): x_shape = [27, 131, 127, 6] - for dtype in [np.float16, np.float32]: - if test.is_gpu_available(cuda_only=True): - self._test_training( - x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW') - self._test_training( - x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC') - self._test_training( - x_shape, dtype, [131], np.float32, use_gpu=False, data_format='NCHW') - self._test_training( - x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC') + self._runtests(x_shape, True) @test_util.disable_xla('b/141236973: Empty inputs wrong on CPU.') def testTrainingShape5(self): x_shape = [0, 131, 127, 6] - for dtype in [np.float16, np.float32]: - if test.is_gpu_available(cuda_only=True): - self._test_training( - x_shape, - dtype, [131], - np.float32, - use_gpu=True, - data_format='NCHW') - self._test_training( - x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC') - self._test_training( - x_shape, - dtype, [131], - np.float32, - use_gpu=False, - data_format='NCHW') - self._test_training( - x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC') + self._runtests(x_shape, True) + + def testBatchNormGradInferenceShape1(self): + x_shape = [1, 1, 6, 1] + self._runtests(x_shape, is_training=False, gradient_test=True) @test_util.run_deprecated_v1 - def testBatchNormGradShape1(self): - for is_training in [True, False]: - x_shape = [1, 1, 6, 1] - for dtype in [np.float16, np.float32]: - if test.is_gpu_available(cuda_only=True): - self._test_gradient( - x_shape, - dtype, [1], - np.float32, - use_gpu=True, - data_format='NHWC', - is_training=is_training) - self._test_gradient( - x_shape, - dtype, [1], - np.float32, - use_gpu=True, - data_format='NCHW', - is_training=is_training) - self._test_gradient( - x_shape, - dtype, [1], - np.float32, - use_gpu=False, - data_format='NHWC', - is_training=is_training) - self._test_gradient( - x_shape, - dtype, [1], - np.float32, - use_gpu=False, - data_format='NCHW', - is_training=is_training) + def testBatchNormGradInferenceShape2(self): + x_shape = [1, 1, 6, 2] + self._runtests(x_shape, is_training=False, gradient_test=True) @test_util.run_deprecated_v1 - def testBatchNormGradShape2(self): - for is_training in [True, False]: - x_shape = [1, 1, 6, 2] - for dtype in [np.float16, np.float32]: - if test.is_gpu_available(cuda_only=True): - self._test_gradient( - x_shape, - dtype, [2], - np.float32, - use_gpu=True, - data_format='NHWC', - is_training=is_training) - self._test_gradient( - x_shape, - dtype, [2], - np.float32, - use_gpu=False, - data_format='NHWC', - is_training=is_training) + def testBatchNormGradInferenceShape3(self): + x_shape = [1, 2, 1, 6] + self._runtests(x_shape, is_training=False, gradient_test=True) @test_util.run_deprecated_v1 - def testBatchNormGradShape3(self): - for is_training in [True, False]: - x_shape = [1, 2, 1, 6] - for dtype in [np.float16, np.float32]: - if test.is_gpu_available(cuda_only=True): - self._test_gradient( - x_shape, - dtype, [2], - np.float32, - use_gpu=True, - data_format='NCHW', - is_training=is_training) - self._test_gradient( - x_shape, - dtype, [2], - np.float32, - use_gpu=False, - data_format='NCHW', - is_training=is_training) - - @test_util.run_deprecated_v1 - def testBatchNormGradShape4(self): - for is_training in [True, False]: - x_shape = [5, 7, 11, 4] - for dtype in [np.float16, np.float32]: - if test.is_gpu_available(cuda_only=True): - self._test_gradient( - x_shape, - 
dtype, [7], - np.float32, - use_gpu=True, - data_format='NCHW', - is_training=is_training) - self._test_gradient( - x_shape, - dtype, [4], - np.float32, - use_gpu=True, - data_format='NHWC', - is_training=is_training) - self._test_gradient( - x_shape, - dtype, [4], - np.float32, - use_gpu=False, - data_format='NHWC', - is_training=is_training) - self._test_gradient( - x_shape, - dtype, [7], - np.float32, - use_gpu=False, - data_format='NCHW', - is_training=is_training) + def testBatchNormGradInferenceShape4(self): + x_shape = [5, 7, 11, 4] + self._runtests(x_shape, is_training=False, gradient_test=True) @test_util.run_deprecated_v1 @test_util.disable_xla('This test never passed for XLA') - def testBatchNormGradShape5(self): - for is_training in [True, False]: - x_shape = [0, 7, 11, 4] - for dtype in [np.float16, np.float32]: - if test.is_gpu_available(cuda_only=True): - self._test_gradient( - x_shape, - dtype, [7], - np.float32, - use_gpu=True, - data_format='NCHW', - is_training=is_training) - self._test_gradient( - x_shape, - dtype, [4], - np.float32, - use_gpu=True, - data_format='NHWC', - is_training=is_training) - self._test_gradient( - x_shape, - dtype, [4], - np.float32, - use_gpu=False, - data_format='NHWC', - is_training=is_training) - self._test_gradient( - x_shape, - dtype, [7], - np.float32, - use_gpu=False, - data_format='NCHW', - is_training=is_training) + def testBatchNormGradInferenceShape5(self): + x_shape = [0, 7, 11, 4] + self._runtests(x_shape, is_training=False, gradient_test=True) + + @test_util.run_deprecated_v1 + def testBatchNormGradTrainingShape1(self): + x_shape = [1, 1, 6, 1] + self._runtests(x_shape, is_training=True, gradient_test=True) + + @test_util.run_deprecated_v1 + def testBatchNormGradTrainingShape2(self): + x_shape = [1, 1, 6, 2] + self._runtests(x_shape, is_training=True, gradient_test=True) + + @test_util.run_deprecated_v1 + def testBatchNormGradTrainingShape3(self): + x_shape = [1, 2, 1, 6] + self._runtests(x_shape, is_training=True, gradient_test=True) + + @test_util.run_deprecated_v1 + def testBatchNormGradTrainingShape4(self): + x_shape = [5, 7, 11, 4] + self._runtests(x_shape, is_training=True, gradient_test=True) + + @test_util.run_deprecated_v1 + @test_util.disable_xla('This test never passed for XLA') + def testBatchNormGradTrainingShape5(self): + x_shape = [0, 7, 11, 4] + self._runtests(x_shape, is_training=True, gradient_test=True) def _testBatchNormGradGrad(self, config): shape = config['shape'] diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index 4cb2caf6028..f8ff4d78e4b 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -883,7 +883,7 @@ def _BaseFusedBatchNormGrad(op, version, *grad): } if version == 2: args["reserve_space_3"] = op.outputs[5] - return grad_fun(**args) + dx, dscale, doffset, _, _ = grad_fun(**args) else: pop_mean = op.inputs[3] pop_var = op.inputs[4] @@ -905,7 +905,8 @@ def _BaseFusedBatchNormGrad(op, version, *grad): dx, dscale, doffset, _, _ = grad_fun(**args) if data_format == b"NCHW": dx = array_ops.transpose(dx, [0, 3, 1, 2]) - return dx, dscale, doffset, None, None + return dx, dscale, doffset, None, None + @ops.RegisterGradient("FusedBatchNorm") def _FusedBatchNormGrad(op, *grad): diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 8b60e1e274e..07ff8bfb4e8 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -1432,7 +1432,8 @@ def fused_batch_norm( epsilon=0.001, 
data_format="NHWC", is_training=True, - name=None): + name=None, + exponential_avg_factor=1.0): r"""Batch normalization. @@ -1444,22 +1445,49 @@ def fused_batch_norm( x: Input `Tensor` of 4 dimensions. scale: A `Tensor` of 1 dimension for scaling. offset: A `Tensor` of 1 dimension for bias. - mean: A `Tensor` of 1 dimension for population mean used for inference. - variance: A `Tensor` of 1 dimension for population variance - used for inference. + mean: A `Tensor` of 1 dimension for population mean. The shape and meaning + of this argument depends on the value of is_training and + exponential_avg_factor as follows: + is_training==False (inference): + Mean must be a `Tensor` of the same shape as scale containing the + estimated population mean computed during training. + is_training==True and exponential_avg_factor == 1.0: + Mean must be None. + is_training==True and exponential_avg_factor != 1.0: + Mean must be a `Tensor` of the same shape as scale containing the + exponential running mean. + variance: A `Tensor` of 1 dimension for population variance. The shape and + meaning of this argument depends on the value of is_training and + exponential_avg_factor as follows: + is_training==False (inference): + Variance must be a `Tensor` of the same shape as scale containing + the estimated population variance computed during training. + is_training==True and exponential_avg_factor == 1.0: + Variance must be None. + is_training==True and exponential_avg_factor != 1.0: + Variance must be a `Tensor` of the same shape as scale containing + the exponential running variance. epsilon: A small float number added to the variance of x. data_format: The data format for x. Either "NHWC" (default) or "NCHW". is_training: A bool value to specify if the operation is used for training or inference. name: A name for this operation (optional). + exponential_avg_factor: A float number (usually between 0 and 1) used + for controlling the decay of the running + population average of mean and variance. + If set to 1.0, the current batch average is + returned. Returns: y: A 4D Tensor for the normalized, scaled, offsetted x. - batch_mean: A 1D Tensor for the mean of x. - batch_var: A 1D Tensor for the variance of x. - - Raises: - ValueError: If mean or variance is not None when is_training is True. + running_mean: A 1D Tensor for the exponential running mean of x. + The output value is (1 - exponential_avg_factor) * mean + + exponential_avg_factor * batch_mean), where batch_mean + is the mean of the current batch in x. + running_var: A 1D Tensor for the exponential running variance + The output value is (1 - exponential_avg_factor) * variance + + exponential_avg_factor * batch_variance), where batch_variance + is the variance of the current batch in x. 
References: Batch Normalization - Accelerating Deep Network Training by Reducing @@ -1467,24 +1495,44 @@ def fused_batch_norm( [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html) ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf)) """ + if is_training and exponential_avg_factor == 1.0: + if (mean is not None) or (variance is not None): + raise ValueError("Both 'mean' and 'variance' must be None when " + "is_training is True and " + "exponential_avg_factor == 1.0.") + else: + if (mean is None) or (variance is None): + raise ValueError("Both 'mean' and 'variance' must be a 1D tensor when " + "is_training is False or " + "exponential_avg_factor != 1.0.") x = ops.convert_to_tensor(x, name="input") scale = ops.convert_to_tensor(scale, name="scale") offset = ops.convert_to_tensor(offset, name="offset") - if is_training: - if (mean is not None) or (variance is not None): - raise ValueError("Both 'mean' and 'variance' must be None " - "if is_training is True.") if mean is None: mean = constant_op.constant([]) if variance is None: variance = constant_op.constant([]) + # Set a minimum epsilon to 1.001e-5, which is a requirement by CUDNN to # prevent exception (see cudnn.h). min_epsilon = 1.001e-5 epsilon = epsilon if epsilon > min_epsilon else min_epsilon - if compat.forward_compatible(2019, 6, 6): - y, batch_mean, batch_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3( + if compat.forward_compatible(2020, 3, 6): + y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3( + x, + scale, + offset, + mean, + variance, + epsilon=epsilon, + exponential_avg_factor=exponential_avg_factor, + data_format=data_format, + is_training=is_training, + name=name) + return y, running_mean, running_var + else: + y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3( x, scale, offset, @@ -1494,23 +1542,7 @@ def fused_batch_norm( data_format=data_format, is_training=is_training, name=name) - return y, batch_mean, batch_var - - if x.dtype == dtypes.float16 or x.dtype == dtypes.bfloat16: - fused_batch_norm_func = gen_nn_ops.fused_batch_norm_v2 - else: - fused_batch_norm_func = gen_nn_ops._fused_batch_norm # pylint: disable=protected-access - y, batch_mean, batch_var, _, _ = fused_batch_norm_func( - x, - scale, - offset, - mean, - variance, - epsilon=epsilon, - data_format=data_format, - is_training=is_training, - name=name) - return y, batch_mean, batch_var + return y, running_mean, running_var @tf_export(v1=["nn.batch_norm_with_global_normalization"]) diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 130841dde5f..35260ad3d42 100755 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -3607,8 +3607,10 @@ port::Status CudnnSupport::DoBatchNormalizationForwardImpl( void* batch_mean_opaque; void* batch_var_opaque; if (!batch_mean->is_null() && !batch_var->is_null()) { - stream->ThenMemZero(batch_mean, batch_mean->size()); - stream->ThenMemZero(batch_var, batch_var->size()); + if (exponential_average_factor == 1.0) { + stream->ThenMemZero(batch_mean, batch_mean->size()); + stream->ThenMemZero(batch_var, batch_var->size()); + } batch_mean_opaque = batch_mean->opaque(); batch_var_opaque = batch_var->opaque(); } else { diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt index 239872b111d..932e5037d99 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt +++ 
b/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt @@ -210,7 +210,7 @@ tf_module { } member_method { name: "fused_batch_norm" - argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.001\', \'NHWC\', \'True\', \'None\'], " + argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\', \'exponential_avg_factor\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.001\', \'NHWC\', \'True\', \'None\', \'1.0\'], " } member_method { name: "in_top_k"
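For readers of this change, the running-average semantics introduced by exponential_avg_factor can be summarized with a short NumPy sketch. This is an illustration only, not code from the patch; the NHWC shape, the factor value 0.6, and the names old_mean/old_var are arbitrary assumptions made for the example.

import numpy as np

def running_average(old_stat, batch_stat, exponential_avg_factor):
    # factor == 1.0 keeps the historical behavior: the op simply returns the
    # statistic of the current batch.
    if exponential_avg_factor == 1.0:
        return batch_stat
    # Otherwise the previous running value is blended with the new batch
    # statistic, mirroring the _running_mean helper in the updated unit tests.
    return ((1.0 - exponential_avg_factor) * old_stat +
            exponential_avg_factor * batch_stat)

x = np.random.random_sample((2, 2, 6, 3)).astype(np.float32)  # NHWC input
old_mean = np.zeros(3, np.float32)   # previous running mean
old_var = np.ones(3, np.float32)     # previous running variance

batch_mean = x.mean(axis=(0, 1, 2))
# fused_batch_norm reports the batch variance with Bessel's correction applied.
n = x.size / batch_mean.size
batch_var = x.var(axis=(0, 1, 2)) * n / max(n - 1.0, 1.0)

new_running_mean = running_average(old_mean, batch_mean, 0.6)
new_running_var = running_average(old_var, batch_var, 0.6)

With the patch applied, this is the update that tf.compat.v1.nn.fused_batch_norm(x, scale, offset, mean=old_mean, variance=old_var, exponential_avg_factor=0.6, is_training=True) is documented to return in its second and third outputs.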