Update int8 comparison kernel and test for negative values.

This commit is contained in:
Yongfeng Gu 2019-08-27 16:35:17 -04:00
parent 0696cc9a62
commit 25967fc3aa
2 changed files with 52 additions and 33 deletions

View File

@ -34,7 +34,7 @@ namespace gpu {
static constexpr double kTolerance = 0.1f;
// Comparison kernel code: compare two buffers of fp16/fp32/fp64 of length
// Comparison kernel code: compare two buffers of fp16/fp32/fp64/int8 of length
// buffer_length where the relative error does not exceed the passed
// rel_error_threshold. Write the number of mismatches into out parameter
// mismatch_count.
@ -46,12 +46,20 @@ static constexpr double kTolerance = 0.1f;
//
// #include<cuda_fp16.h>
// extern "C" { // avoid name mangling
// __device__ float canonicalize(float input) {
// __device__ float canonicalize_fp16(float input) {
// // All fp16 infinities are treated as 65505 or -65505, in order to avoid
// // differences due to overflows.
// return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f));
// }
//
// __device__ float extract_int8(int pack) {
// // Extract the lower 8 bits from pack and convert it to float
// const unsigned int bit_mask = 0xff;
// unsigned int bits = pack & bit_mask;
// char* int8_ptr = (char*)&bits;
// return __int2float_rn(*int8_ptr);
// }
// __global__ void __xla_fp16_comparison(__half* buffer_a, __half* buffer_b,
// float rel_error_threshold,
// unsigned long long buffer_length,
@ -60,15 +68,15 @@ static constexpr double kTolerance = 0.1f;
// if (idx >= buffer_length) return;
// float elem_a = __half2float(buffer_a[idx]);
// float elem_b = __half2float(buffer_b[idx]);
// elem_a = canonicalize(elem_a);
// elem_b = canonicalize(elem_b);
// elem_a = canonicalize_fp16(elem_a);
// elem_b = canonicalize_fp16(elem_b);
// if (isnan(elem_a) && isnan(elem_b)) return;
// float rel_error = abs(elem_a - elem_b)
// / (max(abs(elem_a), abs(elem_b)) + 1);
// if (rel_error > rel_error_threshold || isnan(rel_error))
// atomicAdd(mismatch_count, 1);
// }
//
// __global__ void __xla_fp32_comparison(float* buffer_a, float* buffer_b,
// float rel_error_threshold,
// unsigned long long buffer_length,
@ -85,7 +93,7 @@ static constexpr double kTolerance = 0.1f;
// if (rel_error > rel_error_threshold || isnan(rel_error))
// atomicAdd(mismatch_count, 1);
// }
//
// __global__ void __xla_fp64_comparison(double* buffer_a, double* buffer_b,
// float rel_error_threshold,
// unsigned long long buffer_length,
@ -102,19 +110,25 @@ static constexpr double kTolerance = 0.1f;
// if (rel_error > rel_error_threshold || isnan(rel_error))
// atomicAdd(mismatch_count, 1);
// }
//
// __global__ void __xla_int8_comparison(int8* buffer_a, int8* buffer_b,
// __global__ void __xla_int8_comparison(int* buffer_a, int* buffer_b,
// float rel_error_threshold,
// unsigned long long buffer_length,
// int* mismatch_count) {
// int idx = threadIdx.x + blockIdx.x * blockDim.x;
// if (idx >= buffer_length) return;
// float elem_a = __int2float_rn(buffer_a[idx]);
// float elem_b = __int2float_rn(buffer_b[idx]);
// float rel_error = abs(elem_a - elem_b)
// / (max(abs(elem_a), abs(elem_b)) + 1);
// if (rel_error > rel_error_threshold || isnan(rel_error))
// atomicAdd(mismatch_count, 1);
// int pack_a = buffer_a[idx];
// int pack_b = buffer_b[idx];
// for(int i = 0; i < 4; ++i) {
// float elem_a = extract_int8(pack_a);
// float elem_b = extract_int8(pack_b);
// float rel_error = abs(elem_a - elem_b)
// / (max(abs(elem_a), abs(elem_b)) + 1);
// if (rel_error > rel_error_threshold || isnan(rel_error))
// atomicAdd(mismatch_count, 1);
// pack_a >>= 8;
// pack_b >>= 8;
// }
// }
// } // end extern declaration.
static const char* buffer_compare_ptx = R"(
@ -412,7 +426,7 @@ BB2_11:
{
.reg .pred %p<10>;
.reg .f32 %f<42>;
.reg .b32 %r<19>;
.reg .b32 %r<23>;
.reg .b64 %rd<12>;
@ -436,10 +450,10 @@ BB2_11:
cvta.to.global.u64 %rd10, %rd3;
add.s64 %rd11, %rd10, %rd8;
ld.global.u32 %r2, [%rd9];
and.b32 %r7, %r2, 255;
cvt.s32.s8 %r7, %r2;
cvt.rn.f32.s32 %f6, %r7;
ld.global.u32 %r3, [%rd11];
and.b32 %r8, %r3, 255;
cvt.s32.s8 %r8, %r3;
cvt.rn.f32.s32 %f7, %r8;
sub.f32 %f8, %f6, %f7;
abs.f32 %f9, %f8;
@ -459,10 +473,12 @@ BB3_3:
atom.global.add.u32 %r9, [%rd1], 1;
BB3_4:
bfe.u32 %r10, %r2, 8, 8;
cvt.rn.f32.s32 %f15, %r10;
bfe.u32 %r11, %r3, 8, 8;
cvt.rn.f32.s32 %f16, %r11;
shr.u32 %r10, %r3, 8;
shr.u32 %r11, %r2, 8;
cvt.s32.s8 %r12, %r11;
cvt.rn.f32.s32 %f15, %r12;
cvt.s32.s8 %r13, %r10;
cvt.rn.f32.s32 %f16, %r13;
sub.f32 %f17, %f15, %f16;
abs.f32 %f18, %f17;
abs.f32 %f19, %f15;
@ -478,13 +494,15 @@ BB3_4:
@%p5 bra BB3_7;
BB3_6:
atom.global.add.u32 %r12, [%rd1], 1;
atom.global.add.u32 %r14, [%rd1], 1;
BB3_7:
bfe.u32 %r13, %r2, 16, 8;
cvt.rn.f32.s32 %f24, %r13;
bfe.u32 %r14, %r3, 16, 8;
cvt.rn.f32.s32 %f25, %r14;
shr.u32 %r15, %r3, 16;
shr.u32 %r16, %r2, 16;
cvt.s32.s8 %r17, %r16;
cvt.rn.f32.s32 %f24, %r17;
cvt.s32.s8 %r18, %r15;
cvt.rn.f32.s32 %f25, %r18;
sub.f32 %f26, %f24, %f25;
abs.f32 %f27, %f26;
abs.f32 %f28, %f24;
@ -500,13 +518,13 @@ BB3_7:
@%p7 bra BB3_10;
BB3_9:
atom.global.add.u32 %r15, [%rd1], 1;
atom.global.add.u32 %r19, [%rd1], 1;
BB3_10:
shr.u32 %r16, %r3, 24;
shr.u32 %r17, %r2, 24;
cvt.rn.f32.s32 %f33, %r17;
cvt.rn.f32.s32 %f34, %r16;
shr.s32 %r20, %r2, 24;
cvt.rn.f32.s32 %f33, %r20;
shr.s32 %r21, %r3, 24;
cvt.rn.f32.s32 %f34, %r21;
sub.f32 %f35, %f33, %f34;
abs.f32 %f36, %f35;
abs.f32 %f37, %f33;
@ -522,7 +540,7 @@ BB3_10:
@%p9 bra BB3_13;
BB3_12:
atom.global.add.u32 %r18, [%rd1], 1;
atom.global.add.u32 %r22, [%rd1], 1;
BB3_13:
ret;

View File

@ -184,6 +184,7 @@ TEST_F(BufferComparatorTest, TestNumbers) {
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({9}, {10}));
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({90}, {100}));
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({100}, {90}));
EXPECT_FALSE(CompareEqualFloatBuffers<int8>({-128}, {127}));
}
TEST_F(BufferComparatorTest, TestMultiple) {