[ROCm] Fix for a bug in ROCm batchnorm implementation.

Fixing the bug also makes the unit test `//tensorflow/python/keras/layers:normalization_test_gpu` pass, so this commit removes the `no_rocm` tag from that test as well.
This commit is contained in:
Deven Desai 2021-01-22 17:33:54 +00:00
parent e110426b1f
commit fe54e03d6a
2 changed files with 0 additions and 3 deletions
tensorflow
python/keras/layers
stream_executor/rocm

View File

@@ -715,7 +715,6 @@ cuda_py_test(
python_version = "PY3",
shard_count = 4,
tags = [
"no_rocm",
"notsan",
],
xla_tags = [

View File

@@ -3584,8 +3584,6 @@ bool MIOpenSupport::DoBatchNormalizationForwardImpl(
auto status = miopenStatusInvalidValue;
if (is_training) {
stream->ThenMemZero(batch_mean, batch_mean->size());
stream->ThenMemZero(batch_var, batch_var->size());
status = wrap::miopenBatchNormalizationForwardTraining(
miopen.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),