[ROCm] Fix for a bug in ROCm batchnorm implementation.
Fixing the bug also makes the unit-test `//tensorflow/python/keras/layers:normalization_test_gpu` pass, so removing the `no_rocm` tag from it as well.
This commit is contained in:
parent
e110426b1f
commit
fe54e03d6a
tensorflow
@ -715,7 +715,6 @@ cuda_py_test(
|
||||
python_version = "PY3",
|
||||
shard_count = 4,
|
||||
tags = [
|
||||
"no_rocm",
|
||||
"notsan",
|
||||
],
|
||||
xla_tags = [
|
||||
|
@ -3584,8 +3584,6 @@ bool MIOpenSupport::DoBatchNormalizationForwardImpl(
|
||||
|
||||
auto status = miopenStatusInvalidValue;
|
||||
if (is_training) {
|
||||
stream->ThenMemZero(batch_mean, batch_mean->size());
|
||||
stream->ThenMemZero(batch_var, batch_var->size());
|
||||
status = wrap::miopenBatchNormalizationForwardTraining(
|
||||
miopen.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
|
||||
x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
|
||||
|
Loading…
Reference in New Issue
Block a user