Update int8 comparison kernel and test for negative values.
This commit is contained in:
parent
0696cc9a62
commit
25967fc3aa
@ -34,7 +34,7 @@ namespace gpu {
|
|||||||
|
|
||||||
static constexpr double kTolerance = 0.1f;
|
static constexpr double kTolerance = 0.1f;
|
||||||
|
|
||||||
// Comparison kernel code: compare two buffers of fp16/fp32/fp64 of length
|
// Comparison kernel code: compare two buffers of fp16/fp32/fp64/int8 of length
|
||||||
// buffer_length where the relative error does not exceed the passed
|
// buffer_length where the relative error does not exceed the passed
|
||||||
// rel_error_threshold. Write the number of mismatches into out parameter
|
// rel_error_threshold. Write the number of mismatches into out parameter
|
||||||
// mismatch_count.
|
// mismatch_count.
|
||||||
@ -46,12 +46,20 @@ static constexpr double kTolerance = 0.1f;
|
|||||||
//
|
//
|
||||||
// #include<cuda_fp16.h>
|
// #include<cuda_fp16.h>
|
||||||
// extern "C" { // avoid name mangling
|
// extern "C" { // avoid name mangling
|
||||||
// __device__ float canonicalize(float input) {
|
// __device__ float canonicalize_fp16(float input) {
|
||||||
// // All fp16 infinities are treated as 65505 or -65505, in order to avoid
|
// // All fp16 infinities are treated as 65505 or -65505, in order to avoid
|
||||||
// // differences due to overflows.
|
// // differences due to overflows.
|
||||||
// return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f));
|
// return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f));
|
||||||
// }
|
// }
|
||||||
//
|
|
||||||
|
// __device__ float extract_int8(int pack) {
|
||||||
|
// // Extract the lower 8 bits from pack and convert it to float
|
||||||
|
// const unsigned int bit_mask = 0xff;
|
||||||
|
// unsigned int bits = pack & bit_mask;
|
||||||
|
// char* int8_ptr = (char*)&bits;
|
||||||
|
// return __int2float_rn(*int8_ptr);
|
||||||
|
// }
|
||||||
|
|
||||||
// __global__ void __xla_fp16_comparison(__half* buffer_a, __half* buffer_b,
|
// __global__ void __xla_fp16_comparison(__half* buffer_a, __half* buffer_b,
|
||||||
// float rel_error_threshold,
|
// float rel_error_threshold,
|
||||||
// unsigned long long buffer_length,
|
// unsigned long long buffer_length,
|
||||||
@ -60,15 +68,15 @@ static constexpr double kTolerance = 0.1f;
|
|||||||
// if (idx >= buffer_length) return;
|
// if (idx >= buffer_length) return;
|
||||||
// float elem_a = __half2float(buffer_a[idx]);
|
// float elem_a = __half2float(buffer_a[idx]);
|
||||||
// float elem_b = __half2float(buffer_b[idx]);
|
// float elem_b = __half2float(buffer_b[idx]);
|
||||||
// elem_a = canonicalize(elem_a);
|
// elem_a = canonicalize_fp16(elem_a);
|
||||||
// elem_b = canonicalize(elem_b);
|
// elem_b = canonicalize_fp16(elem_b);
|
||||||
// if (isnan(elem_a) && isnan(elem_b)) return;
|
// if (isnan(elem_a) && isnan(elem_b)) return;
|
||||||
// float rel_error = abs(elem_a - elem_b)
|
// float rel_error = abs(elem_a - elem_b)
|
||||||
// / (max(abs(elem_a), abs(elem_b)) + 1);
|
// / (max(abs(elem_a), abs(elem_b)) + 1);
|
||||||
// if (rel_error > rel_error_threshold || isnan(rel_error))
|
// if (rel_error > rel_error_threshold || isnan(rel_error))
|
||||||
// atomicAdd(mismatch_count, 1);
|
// atomicAdd(mismatch_count, 1);
|
||||||
// }
|
// }
|
||||||
//
|
|
||||||
// __global__ void __xla_fp32_comparison(float* buffer_a, float* buffer_b,
|
// __global__ void __xla_fp32_comparison(float* buffer_a, float* buffer_b,
|
||||||
// float rel_error_threshold,
|
// float rel_error_threshold,
|
||||||
// unsigned long long buffer_length,
|
// unsigned long long buffer_length,
|
||||||
@ -85,7 +93,7 @@ static constexpr double kTolerance = 0.1f;
|
|||||||
// if (rel_error > rel_error_threshold || isnan(rel_error))
|
// if (rel_error > rel_error_threshold || isnan(rel_error))
|
||||||
// atomicAdd(mismatch_count, 1);
|
// atomicAdd(mismatch_count, 1);
|
||||||
// }
|
// }
|
||||||
//
|
|
||||||
// __global__ void __xla_fp64_comparison(double* buffer_a, double* buffer_b,
|
// __global__ void __xla_fp64_comparison(double* buffer_a, double* buffer_b,
|
||||||
// float rel_error_threshold,
|
// float rel_error_threshold,
|
||||||
// unsigned long long buffer_length,
|
// unsigned long long buffer_length,
|
||||||
@ -102,19 +110,25 @@ static constexpr double kTolerance = 0.1f;
|
|||||||
// if (rel_error > rel_error_threshold || isnan(rel_error))
|
// if (rel_error > rel_error_threshold || isnan(rel_error))
|
||||||
// atomicAdd(mismatch_count, 1);
|
// atomicAdd(mismatch_count, 1);
|
||||||
// }
|
// }
|
||||||
//
|
|
||||||
// __global__ void __xla_int8_comparison(int8* buffer_a, int8* buffer_b,
|
// __global__ void __xla_int8_comparison(int* buffer_a, int* buffer_b,
|
||||||
// float rel_error_threshold,
|
// float rel_error_threshold,
|
||||||
// unsigned long long buffer_length,
|
// unsigned long long buffer_length,
|
||||||
// int* mismatch_count) {
|
// int* mismatch_count) {
|
||||||
// int idx = threadIdx.x + blockIdx.x * blockDim.x;
|
// int idx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
// if (idx >= buffer_length) return;
|
// if (idx >= buffer_length) return;
|
||||||
// float elem_a = __int2float_rn(buffer_a[idx]);
|
// int pack_a = buffer_a[idx];
|
||||||
// float elem_b = __int2float_rn(buffer_b[idx]);
|
// int pack_b = buffer_b[idx];
|
||||||
// float rel_error = abs(elem_a - elem_b)
|
// for(int i = 0; i < 4; ++i) {
|
||||||
// / (max(abs(elem_a), abs(elem_b)) + 1);
|
// float elem_a = extract_int8(pack_a);
|
||||||
// if (rel_error > rel_error_threshold || isnan(rel_error))
|
// float elem_b = extract_int8(pack_b);
|
||||||
// atomicAdd(mismatch_count, 1);
|
// float rel_error = abs(elem_a - elem_b)
|
||||||
|
// / (max(abs(elem_a), abs(elem_b)) + 1);
|
||||||
|
// if (rel_error > rel_error_threshold || isnan(rel_error))
|
||||||
|
// atomicAdd(mismatch_count, 1);
|
||||||
|
// pack_a >>= 8;
|
||||||
|
// pack_b >>= 8;
|
||||||
|
// }
|
||||||
// }
|
// }
|
||||||
// } // end extern declaration.
|
// } // end extern declaration.
|
||||||
static const char* buffer_compare_ptx = R"(
|
static const char* buffer_compare_ptx = R"(
|
||||||
@ -412,7 +426,7 @@ BB2_11:
|
|||||||
{
|
{
|
||||||
.reg .pred %p<10>;
|
.reg .pred %p<10>;
|
||||||
.reg .f32 %f<42>;
|
.reg .f32 %f<42>;
|
||||||
.reg .b32 %r<19>;
|
.reg .b32 %r<23>;
|
||||||
.reg .b64 %rd<12>;
|
.reg .b64 %rd<12>;
|
||||||
|
|
||||||
|
|
||||||
@ -436,10 +450,10 @@ BB2_11:
|
|||||||
cvta.to.global.u64 %rd10, %rd3;
|
cvta.to.global.u64 %rd10, %rd3;
|
||||||
add.s64 %rd11, %rd10, %rd8;
|
add.s64 %rd11, %rd10, %rd8;
|
||||||
ld.global.u32 %r2, [%rd9];
|
ld.global.u32 %r2, [%rd9];
|
||||||
and.b32 %r7, %r2, 255;
|
cvt.s32.s8 %r7, %r2;
|
||||||
cvt.rn.f32.s32 %f6, %r7;
|
cvt.rn.f32.s32 %f6, %r7;
|
||||||
ld.global.u32 %r3, [%rd11];
|
ld.global.u32 %r3, [%rd11];
|
||||||
and.b32 %r8, %r3, 255;
|
cvt.s32.s8 %r8, %r3;
|
||||||
cvt.rn.f32.s32 %f7, %r8;
|
cvt.rn.f32.s32 %f7, %r8;
|
||||||
sub.f32 %f8, %f6, %f7;
|
sub.f32 %f8, %f6, %f7;
|
||||||
abs.f32 %f9, %f8;
|
abs.f32 %f9, %f8;
|
||||||
@ -459,10 +473,12 @@ BB3_3:
|
|||||||
atom.global.add.u32 %r9, [%rd1], 1;
|
atom.global.add.u32 %r9, [%rd1], 1;
|
||||||
|
|
||||||
BB3_4:
|
BB3_4:
|
||||||
bfe.u32 %r10, %r2, 8, 8;
|
shr.u32 %r10, %r3, 8;
|
||||||
cvt.rn.f32.s32 %f15, %r10;
|
shr.u32 %r11, %r2, 8;
|
||||||
bfe.u32 %r11, %r3, 8, 8;
|
cvt.s32.s8 %r12, %r11;
|
||||||
cvt.rn.f32.s32 %f16, %r11;
|
cvt.rn.f32.s32 %f15, %r12;
|
||||||
|
cvt.s32.s8 %r13, %r10;
|
||||||
|
cvt.rn.f32.s32 %f16, %r13;
|
||||||
sub.f32 %f17, %f15, %f16;
|
sub.f32 %f17, %f15, %f16;
|
||||||
abs.f32 %f18, %f17;
|
abs.f32 %f18, %f17;
|
||||||
abs.f32 %f19, %f15;
|
abs.f32 %f19, %f15;
|
||||||
@ -478,13 +494,15 @@ BB3_4:
|
|||||||
@%p5 bra BB3_7;
|
@%p5 bra BB3_7;
|
||||||
|
|
||||||
BB3_6:
|
BB3_6:
|
||||||
atom.global.add.u32 %r12, [%rd1], 1;
|
atom.global.add.u32 %r14, [%rd1], 1;
|
||||||
|
|
||||||
BB3_7:
|
BB3_7:
|
||||||
bfe.u32 %r13, %r2, 16, 8;
|
shr.u32 %r15, %r3, 16;
|
||||||
cvt.rn.f32.s32 %f24, %r13;
|
shr.u32 %r16, %r2, 16;
|
||||||
bfe.u32 %r14, %r3, 16, 8;
|
cvt.s32.s8 %r17, %r16;
|
||||||
cvt.rn.f32.s32 %f25, %r14;
|
cvt.rn.f32.s32 %f24, %r17;
|
||||||
|
cvt.s32.s8 %r18, %r15;
|
||||||
|
cvt.rn.f32.s32 %f25, %r18;
|
||||||
sub.f32 %f26, %f24, %f25;
|
sub.f32 %f26, %f24, %f25;
|
||||||
abs.f32 %f27, %f26;
|
abs.f32 %f27, %f26;
|
||||||
abs.f32 %f28, %f24;
|
abs.f32 %f28, %f24;
|
||||||
@ -500,13 +518,13 @@ BB3_7:
|
|||||||
@%p7 bra BB3_10;
|
@%p7 bra BB3_10;
|
||||||
|
|
||||||
BB3_9:
|
BB3_9:
|
||||||
atom.global.add.u32 %r15, [%rd1], 1;
|
atom.global.add.u32 %r19, [%rd1], 1;
|
||||||
|
|
||||||
BB3_10:
|
BB3_10:
|
||||||
shr.u32 %r16, %r3, 24;
|
shr.s32 %r20, %r2, 24;
|
||||||
shr.u32 %r17, %r2, 24;
|
cvt.rn.f32.s32 %f33, %r20;
|
||||||
cvt.rn.f32.s32 %f33, %r17;
|
shr.s32 %r21, %r3, 24;
|
||||||
cvt.rn.f32.s32 %f34, %r16;
|
cvt.rn.f32.s32 %f34, %r21;
|
||||||
sub.f32 %f35, %f33, %f34;
|
sub.f32 %f35, %f33, %f34;
|
||||||
abs.f32 %f36, %f35;
|
abs.f32 %f36, %f35;
|
||||||
abs.f32 %f37, %f33;
|
abs.f32 %f37, %f33;
|
||||||
@ -522,7 +540,7 @@ BB3_10:
|
|||||||
@%p9 bra BB3_13;
|
@%p9 bra BB3_13;
|
||||||
|
|
||||||
BB3_12:
|
BB3_12:
|
||||||
atom.global.add.u32 %r18, [%rd1], 1;
|
atom.global.add.u32 %r22, [%rd1], 1;
|
||||||
|
|
||||||
BB3_13:
|
BB3_13:
|
||||||
ret;
|
ret;
|
||||||
|
@ -184,6 +184,7 @@ TEST_F(BufferComparatorTest, TestNumbers) {
|
|||||||
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({9}, {10}));
|
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({9}, {10}));
|
||||||
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({90}, {100}));
|
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({90}, {100}));
|
||||||
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({100}, {90}));
|
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({100}, {90}));
|
||||||
|
EXPECT_FALSE(CompareEqualFloatBuffers<int8>({-128}, {127}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(BufferComparatorTest, TestMultiple) {
|
TEST_F(BufferComparatorTest, TestMultiple) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user