Use native_rsqrt in MeanStdDevNormalization. It seems to have enough precision on tested hardware.

PiperOrigin-RevId: 329599240
Change-Id: Ib371d4703510738181bf537f9bb5054aeba7d07d
commit 5b64fd1891 (parent 2c9ffb560c)
Author: Robert David
Date:   2020-09-01 16:00:14 -07:00
Committed by: TensorFlower Gardener

@@ -135,7 +135,7 @@ std::string MeanStdDevNormalization::GetNormalizationCode() {
   const float sum_diff_sq = local_reduce(private_sum_diff_sq, tmp);
   // Calculate 1/stddev (with the 'regulazing constant' as in tensor_utils.cc)
   const float variance = sum_diff_sq / args.src_tensor.Channels();
-  const float stddev_inv = rsqrt(variance + 1.0e-8f);
+  const float stddev_inv = native_rsqrt(variance + 1.0e-8f);
   // Calculate (t-mean)/stddev for each element
   for (int S = get_local_id(0); S < args.src_tensor.Slices(); S += get_local_size(0)) {
     const float4 t = args.src_tensor.Read<float>(0, 0, S, B);
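
For context: OpenCL's rsqrt is required by the spec to be accurate to within 2 ulp, while native_rsqrt maps to the device's fast reciprocal-square-root instruction, whose precision is implementation-defined. The sketch below is a hypothetical standalone probe kernel, not part of this change, that could be used to measure the actual error of native_rsqrt on a given device; the kernel name and buffer layout are illustrative assumptions.

// Hypothetical probe kernel (not in the TensorFlow delegate): compares the
// correctly-rounded rsqrt against the implementation-defined native_rsqrt
// over a buffer of variance-like inputs.
__kernel void rsqrt_precision_probe(__global const float* variance,
                                    __global float* rel_error,
                                    int n) {
  int i = get_global_id(0);
  if (i >= n) return;
  // Same regularized form as in the normalization code above.
  float x = variance[i] + 1.0e-8f;
  float exact = rsqrt(x);        // <= 2 ulp per the OpenCL spec
  float fast = native_rsqrt(x);  // precision is implementation-defined
  rel_error[i] = fabs(fast - exact) / exact;
}

Running such a probe over the range of variances the normalization actually produces would give per-device evidence for the "enough precision on tested hardware" claim in the commit message.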