Uniquify function names further.
This commit is contained in:
parent
25967fc3aa
commit
6e14ed49f6
@ -46,13 +46,13 @@ static constexpr double kTolerance = 0.1f;
|
||||
//
|
||||
// #include<cuda_fp16.h>
|
||||
// extern "C" { // avoid name mangling
|
||||
// __device__ float canonicalize_fp16(float input) {
|
||||
// __device__ float __xla_buffer_comparator_canonicalize(float input) {
|
||||
// // All fp16 infinities are treated as 65505 or -65505, in order to avoid
|
||||
// // differences due to overflows.
|
||||
// return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f));
|
||||
// }
|
||||
|
||||
// __device__ float extract_int8(int pack) {
|
||||
// __device__ float __xla_buffer_comparator_extract_int8(int pack) {
|
||||
// // Extract the lower 8 bits from pack and convert it to float
|
||||
// const unsigned int bit_mask = 0xff;
|
||||
// unsigned int bits = pack & bit_mask;
|
||||
@ -68,8 +68,8 @@ static constexpr double kTolerance = 0.1f;
|
||||
// if (idx >= buffer_length) return;
|
||||
// float elem_a = __half2float(buffer_a[idx]);
|
||||
// float elem_b = __half2float(buffer_b[idx]);
|
||||
// elem_a = canonicalize_fp16(elem_a);
|
||||
// elem_b = canonicalize_fp16(elem_b);
|
||||
// elem_a = __xla_buffer_comparator_canonicalize(elem_a);
|
||||
// elem_b = __xla_buffer_comparator_canonicalize(elem_b);
|
||||
// if (isnan(elem_a) && isnan(elem_b)) return;
|
||||
// float rel_error = abs(elem_a - elem_b)
|
||||
// / (max(abs(elem_a), abs(elem_b)) + 1);
|
||||
@ -120,8 +120,8 @@ static constexpr double kTolerance = 0.1f;
|
||||
// int pack_a = buffer_a[idx];
|
||||
// int pack_b = buffer_b[idx];
|
||||
// for(int i = 0; i < 4; ++i) {
|
||||
// float elem_a = extract_int8(pack_a);
|
||||
// float elem_b = extract_int8(pack_b);
|
||||
// float elem_a = __xla_buffer_comparator_extract_int8(pack_a);
|
||||
// float elem_b = __xla_buffer_comparator_extract_int8(pack_b);
|
||||
// float rel_error = abs(elem_a - elem_b)
|
||||
// / (max(abs(elem_a), abs(elem_b)) + 1);
|
||||
// if (rel_error > rel_error_threshold || isnan(rel_error))
|
||||
|
Loading…
x
Reference in New Issue
Block a user