Use native_rsqrt in MeanStdDevNormalization. It appears to have sufficient precision on the tested hardware.
PiperOrigin-RevId: 329599240 Change-Id: Ib371d4703510738181bf537f9bb5054aeba7d07d
commit 5b64fd1891
parent 2c9ffb560c
@@ -135,7 +135,7 @@ std::string MeanStdDevNormalization::GetNormalizationCode() {
   const float sum_diff_sq = local_reduce(private_sum_diff_sq, tmp);
   // Calculate 1/stddev (with the 'regularizing constant' as in tensor_utils.cc)
   const float variance = sum_diff_sq / args.src_tensor.Channels();
-  const float stddev_inv = rsqrt(variance + 1.0e-8f);
+  const float stddev_inv = native_rsqrt(variance + 1.0e-8f);
   // Calculate (t-mean)/stddev for each element
   for (int S = get_local_id(0); S < args.src_tensor.Slices(); S += get_local_size(0)) {
     const float4 t = args.src_tensor.Read<float>(0, 0, S, B);
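For context: in OpenCL, rsqrt must meet the precision bounds of the spec (at most a few ulp of error), while native_rsqrt maps to the device's fast reciprocal-square-root path with implementation-defined accuracy, which is why it is typically faster. Below is a minimal sketch, not part of this commit, of a standalone kernel one could use to check native_rsqrt's deviation on a target device; the kernel name and buffer parameters are hypothetical.

// Hypothetical verification kernel (not from this commit): measures the
// per-element relative error of native_rsqrt against the spec-conforming rsqrt.
__kernel void rsqrt_error(__global const float* in,   // positive test inputs
                          __global float* rel_error,  // per-element relative error
                          int n) {
  int i = get_global_id(0);
  if (i >= n) return;
  float precise = rsqrt(in[i]);      // accuracy guaranteed by the OpenCL spec
  float fast = native_rsqrt(in[i]);  // implementation-defined accuracy
  rel_error[i] = fabs(fast - precise) / fabs(precise);
}

Note also that the argument in this kernel is variance + 1.0e-8f, so the input to native_rsqrt is bounded away from zero and never falls in the near-zero range where native approximations tend to be least accurate.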