From 8c99eb5f8b3b9ad44d79c33207759a0b114e86b9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 12 Mar 2019 01:30:41 -0700 Subject: [PATCH] Fix argument buffer reuse for BatchNormGrad. When scale_after_normalization is set, the kernel may not alias the gamma input with db output, as gamma is still needed. PiperOrigin-RevId: 237970041 --- tensorflow/core/kernels/batch_norm_op.cc | 8 ++++++-- tensorflow/python/BUILD | 1 + tensorflow/python/ops/nn_batchnorm_test.py | 1 - 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/kernels/batch_norm_op.cc b/tensorflow/core/kernels/batch_norm_op.cc index c34ea14bf60..609ddd68caf 100644 --- a/tensorflow/core/kernels/batch_norm_op.cc +++ b/tensorflow/core/kernels/batch_norm_op.cc @@ -127,8 +127,12 @@ class BatchNormGradOp : public OpKernel { OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( {2}, 2, var.shape(), &dv)); Tensor* db = nullptr; - OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( - {3}, 3, mean.shape(), &db)); + if (scale_after_normalization_) { + OP_REQUIRES_OK(context, context->allocate_output(3, mean.shape(), &db)); + } else { + OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( + {3}, 3, mean.shape(), &db)); + } Tensor* dg = nullptr; OP_REQUIRES_OK(context, context->allocate_output(4, gamma.shape(), &dg)); diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 3ae7b8f315d..b4080e95a63 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3711,6 +3711,7 @@ cuda_py_test( ], shard_count = 4, tags = ["no_windows"], + xla_enable_strict_auto_jit = True, ) cuda_py_test( diff --git a/tensorflow/python/ops/nn_batchnorm_test.py b/tensorflow/python/ops/nn_batchnorm_test.py index fedf8e44c3d..e978f1d3260 100644 --- a/tensorflow/python/ops/nn_batchnorm_test.py +++ b/tensorflow/python/ops/nn_batchnorm_test.py @@ -206,7 +206,6 @@ class BatchNormalizationTest(test.TestCase): 2) @test_util.run_deprecated_v1 - @test_util.disable_xla("This test never passed for XLA") def testBatchNormGradImpl(self): x_shape = [7, 5, 4, 6] param_shape = [6]