Fix argument buffer reuse for BatchNormGrad.
When scale_after_normalization is set, the kernel may not alias the gamma input with db output, as gamma is still needed. PiperOrigin-RevId: 237970041
This commit is contained in:
parent
74b961829e
commit
8c99eb5f8b
@ -127,8 +127,12 @@ class BatchNormGradOp : public OpKernel {
|
|||||||
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
|
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
|
||||||
{2}, 2, var.shape(), &dv));
|
{2}, 2, var.shape(), &dv));
|
||||||
Tensor* db = nullptr;
|
Tensor* db = nullptr;
|
||||||
|
if (scale_after_normalization_) {
|
||||||
|
OP_REQUIRES_OK(context, context->allocate_output(3, mean.shape(), &db));
|
||||||
|
} else {
|
||||||
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
|
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
|
||||||
{3}, 3, mean.shape(), &db));
|
{3}, 3, mean.shape(), &db));
|
||||||
|
}
|
||||||
Tensor* dg = nullptr;
|
Tensor* dg = nullptr;
|
||||||
OP_REQUIRES_OK(context, context->allocate_output(4, gamma.shape(), &dg));
|
OP_REQUIRES_OK(context, context->allocate_output(4, gamma.shape(), &dg));
|
||||||
|
|
||||||
|
@ -3711,6 +3711,7 @@ cuda_py_test(
|
|||||||
],
|
],
|
||||||
shard_count = 4,
|
shard_count = 4,
|
||||||
tags = ["no_windows"],
|
tags = ["no_windows"],
|
||||||
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
|
@ -206,7 +206,6 @@ class BatchNormalizationTest(test.TestCase):
|
|||||||
2)
|
2)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
@test_util.disable_xla("This test never passed for XLA")
|
|
||||||
def testBatchNormGradImpl(self):
|
def testBatchNormGradImpl(self):
|
||||||
x_shape = [7, 5, 4, 6]
|
x_shape = [7, 5, 4, 6]
|
||||||
param_shape = [6]
|
param_shape = [6]
|
||||||
|
Loading…
Reference in New Issue
Block a user