Update int8 comparison kernel and test for negative values.
This commit is contained in:
parent
0696cc9a62
commit
25967fc3aa
@ -34,7 +34,7 @@ namespace gpu {
|
||||
|
||||
static constexpr double kTolerance = 0.1f;
|
||||
|
||||
// Comparison kernel code: compare two buffers of fp16/fp32/fp64 of length
|
||||
// Comparison kernel code: compare two buffers of fp16/fp32/fp64/int8 of length
|
||||
// buffer_length where the relative error does not exceed the passed
|
||||
// rel_error_threshold. Write the number of mismatches into out parameter
|
||||
// mismatch_count.
|
||||
@ -46,12 +46,20 @@ static constexpr double kTolerance = 0.1f;
|
||||
//
|
||||
// #include<cuda_fp16.h>
|
||||
// extern "C" { // avoid name mangling
|
||||
// __device__ float canonicalize(float input) {
|
||||
// __device__ float canonicalize_fp16(float input) {
|
||||
// // All fp16 infinities are treated as 65505 or -65505, in order to avoid
|
||||
// // differences due to overflows.
|
||||
// return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f));
|
||||
// }
|
||||
//
|
||||
|
||||
// __device__ float extract_int8(int pack) {
|
||||
// // Extract the lower 8 bits from pack and convert it to float
|
||||
// const unsigned int bit_mask = 0xff;
|
||||
// unsigned int bits = pack & bit_mask;
|
||||
// char* int8_ptr = (char*)&bits;
|
||||
// return __int2float_rn(*int8_ptr);
|
||||
// }
|
||||
|
||||
// __global__ void __xla_fp16_comparison(__half* buffer_a, __half* buffer_b,
|
||||
// float rel_error_threshold,
|
||||
// unsigned long long buffer_length,
|
||||
@ -60,15 +68,15 @@ static constexpr double kTolerance = 0.1f;
|
||||
// if (idx >= buffer_length) return;
|
||||
// float elem_a = __half2float(buffer_a[idx]);
|
||||
// float elem_b = __half2float(buffer_b[idx]);
|
||||
// elem_a = canonicalize(elem_a);
|
||||
// elem_b = canonicalize(elem_b);
|
||||
// elem_a = canonicalize_fp16(elem_a);
|
||||
// elem_b = canonicalize_fp16(elem_b);
|
||||
// if (isnan(elem_a) && isnan(elem_b)) return;
|
||||
// float rel_error = abs(elem_a - elem_b)
|
||||
// / (max(abs(elem_a), abs(elem_b)) + 1);
|
||||
// if (rel_error > rel_error_threshold || isnan(rel_error))
|
||||
// atomicAdd(mismatch_count, 1);
|
||||
// }
|
||||
//
|
||||
|
||||
// __global__ void __xla_fp32_comparison(float* buffer_a, float* buffer_b,
|
||||
// float rel_error_threshold,
|
||||
// unsigned long long buffer_length,
|
||||
@ -85,7 +93,7 @@ static constexpr double kTolerance = 0.1f;
|
||||
// if (rel_error > rel_error_threshold || isnan(rel_error))
|
||||
// atomicAdd(mismatch_count, 1);
|
||||
// }
|
||||
//
|
||||
|
||||
// __global__ void __xla_fp64_comparison(double* buffer_a, double* buffer_b,
|
||||
// float rel_error_threshold,
|
||||
// unsigned long long buffer_length,
|
||||
@ -102,19 +110,25 @@ static constexpr double kTolerance = 0.1f;
|
||||
// if (rel_error > rel_error_threshold || isnan(rel_error))
|
||||
// atomicAdd(mismatch_count, 1);
|
||||
// }
|
||||
//
|
||||
// __global__ void __xla_int8_comparison(int8* buffer_a, int8* buffer_b,
|
||||
|
||||
// __global__ void __xla_int8_comparison(int* buffer_a, int* buffer_b,
|
||||
// float rel_error_threshold,
|
||||
// unsigned long long buffer_length,
|
||||
// int* mismatch_count) {
|
||||
// int idx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
// if (idx >= buffer_length) return;
|
||||
// float elem_a = __int2float_rn(buffer_a[idx]);
|
||||
// float elem_b = __int2float_rn(buffer_b[idx]);
|
||||
// float rel_error = abs(elem_a - elem_b)
|
||||
// / (max(abs(elem_a), abs(elem_b)) + 1);
|
||||
// if (rel_error > rel_error_threshold || isnan(rel_error))
|
||||
// atomicAdd(mismatch_count, 1);
|
||||
// int pack_a = buffer_a[idx];
|
||||
// int pack_b = buffer_b[idx];
|
||||
// for(int i = 0; i < 4; ++i) {
|
||||
// float elem_a = extract_int8(pack_a);
|
||||
// float elem_b = extract_int8(pack_b);
|
||||
// float rel_error = abs(elem_a - elem_b)
|
||||
// / (max(abs(elem_a), abs(elem_b)) + 1);
|
||||
// if (rel_error > rel_error_threshold || isnan(rel_error))
|
||||
// atomicAdd(mismatch_count, 1);
|
||||
// pack_a >>= 8;
|
||||
// pack_b >>= 8;
|
||||
// }
|
||||
// }
|
||||
// } // end extern declaration.
|
||||
static const char* buffer_compare_ptx = R"(
|
||||
@ -412,7 +426,7 @@ BB2_11:
|
||||
{
|
||||
.reg .pred %p<10>;
|
||||
.reg .f32 %f<42>;
|
||||
.reg .b32 %r<19>;
|
||||
.reg .b32 %r<23>;
|
||||
.reg .b64 %rd<12>;
|
||||
|
||||
|
||||
@ -436,10 +450,10 @@ BB2_11:
|
||||
cvta.to.global.u64 %rd10, %rd3;
|
||||
add.s64 %rd11, %rd10, %rd8;
|
||||
ld.global.u32 %r2, [%rd9];
|
||||
and.b32 %r7, %r2, 255;
|
||||
cvt.s32.s8 %r7, %r2;
|
||||
cvt.rn.f32.s32 %f6, %r7;
|
||||
ld.global.u32 %r3, [%rd11];
|
||||
and.b32 %r8, %r3, 255;
|
||||
cvt.s32.s8 %r8, %r3;
|
||||
cvt.rn.f32.s32 %f7, %r8;
|
||||
sub.f32 %f8, %f6, %f7;
|
||||
abs.f32 %f9, %f8;
|
||||
@ -459,10 +473,12 @@ BB3_3:
|
||||
atom.global.add.u32 %r9, [%rd1], 1;
|
||||
|
||||
BB3_4:
|
||||
bfe.u32 %r10, %r2, 8, 8;
|
||||
cvt.rn.f32.s32 %f15, %r10;
|
||||
bfe.u32 %r11, %r3, 8, 8;
|
||||
cvt.rn.f32.s32 %f16, %r11;
|
||||
shr.u32 %r10, %r3, 8;
|
||||
shr.u32 %r11, %r2, 8;
|
||||
cvt.s32.s8 %r12, %r11;
|
||||
cvt.rn.f32.s32 %f15, %r12;
|
||||
cvt.s32.s8 %r13, %r10;
|
||||
cvt.rn.f32.s32 %f16, %r13;
|
||||
sub.f32 %f17, %f15, %f16;
|
||||
abs.f32 %f18, %f17;
|
||||
abs.f32 %f19, %f15;
|
||||
@ -478,13 +494,15 @@ BB3_4:
|
||||
@%p5 bra BB3_7;
|
||||
|
||||
BB3_6:
|
||||
atom.global.add.u32 %r12, [%rd1], 1;
|
||||
atom.global.add.u32 %r14, [%rd1], 1;
|
||||
|
||||
BB3_7:
|
||||
bfe.u32 %r13, %r2, 16, 8;
|
||||
cvt.rn.f32.s32 %f24, %r13;
|
||||
bfe.u32 %r14, %r3, 16, 8;
|
||||
cvt.rn.f32.s32 %f25, %r14;
|
||||
shr.u32 %r15, %r3, 16;
|
||||
shr.u32 %r16, %r2, 16;
|
||||
cvt.s32.s8 %r17, %r16;
|
||||
cvt.rn.f32.s32 %f24, %r17;
|
||||
cvt.s32.s8 %r18, %r15;
|
||||
cvt.rn.f32.s32 %f25, %r18;
|
||||
sub.f32 %f26, %f24, %f25;
|
||||
abs.f32 %f27, %f26;
|
||||
abs.f32 %f28, %f24;
|
||||
@ -500,13 +518,13 @@ BB3_7:
|
||||
@%p7 bra BB3_10;
|
||||
|
||||
BB3_9:
|
||||
atom.global.add.u32 %r15, [%rd1], 1;
|
||||
atom.global.add.u32 %r19, [%rd1], 1;
|
||||
|
||||
BB3_10:
|
||||
shr.u32 %r16, %r3, 24;
|
||||
shr.u32 %r17, %r2, 24;
|
||||
cvt.rn.f32.s32 %f33, %r17;
|
||||
cvt.rn.f32.s32 %f34, %r16;
|
||||
shr.s32 %r20, %r2, 24;
|
||||
cvt.rn.f32.s32 %f33, %r20;
|
||||
shr.s32 %r21, %r3, 24;
|
||||
cvt.rn.f32.s32 %f34, %r21;
|
||||
sub.f32 %f35, %f33, %f34;
|
||||
abs.f32 %f36, %f35;
|
||||
abs.f32 %f37, %f33;
|
||||
@ -522,7 +540,7 @@ BB3_10:
|
||||
@%p9 bra BB3_13;
|
||||
|
||||
BB3_12:
|
||||
atom.global.add.u32 %r18, [%rd1], 1;
|
||||
atom.global.add.u32 %r22, [%rd1], 1;
|
||||
|
||||
BB3_13:
|
||||
ret;
|
||||
|
@ -184,6 +184,7 @@ TEST_F(BufferComparatorTest, TestNumbers) {
|
||||
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({9}, {10}));
|
||||
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({90}, {100}));
|
||||
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({100}, {90}));
|
||||
EXPECT_FALSE(CompareEqualFloatBuffers<int8>({-128}, {127}));
|
||||
}
|
||||
|
||||
TEST_F(BufferComparatorTest, TestMultiple) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user