Fix argument buffer reuse for BatchNormGrad.

When scale_after_normalization is set, the kernel may not alias the gamma
input with db output, as gamma is still needed.

PiperOrigin-RevId: 237970041
This commit is contained in:
A. Unique TensorFlower 2019-03-12 01:30:41 -07:00 committed by TensorFlower Gardener
parent 74b961829e
commit 8c99eb5f8b
3 changed files with 7 additions and 3 deletions

View File

@ -127,8 +127,12 @@ class BatchNormGradOp : public OpKernel {
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{2}, 2, var.shape(), &dv)); {2}, 2, var.shape(), &dv));
Tensor* db = nullptr; Tensor* db = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( if (scale_after_normalization_) {
{3}, 3, mean.shape(), &db)); OP_REQUIRES_OK(context, context->allocate_output(3, mean.shape(), &db));
} else {
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{3}, 3, mean.shape(), &db));
}
Tensor* dg = nullptr; Tensor* dg = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(4, gamma.shape(), &dg)); OP_REQUIRES_OK(context, context->allocate_output(4, gamma.shape(), &dg));

View File

@ -3711,6 +3711,7 @@ cuda_py_test(
], ],
shard_count = 4, shard_count = 4,
tags = ["no_windows"], tags = ["no_windows"],
xla_enable_strict_auto_jit = True,
) )
cuda_py_test( cuda_py_test(

View File

@ -206,7 +206,6 @@ class BatchNormalizationTest(test.TestCase):
2) 2)
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
@test_util.disable_xla("This test never passed for XLA")
def testBatchNormGradImpl(self): def testBatchNormGradImpl(self):
x_shape = [7, 5, 4, 6] x_shape = [7, 5, 4, 6]
param_shape = [6] param_shape = [6]