Merge pull request #31988 from yongfeng-nv:xla_int8_buffer_comparator
PiperOrigin-RevId: 265890572
This commit is contained in:
commit
c4fc64c728
@ -34,7 +34,7 @@ namespace gpu {
|
||||
|
||||
static constexpr double kTolerance = 0.1f;
|
||||
|
||||
// Comparison kernel code: compare two buffers of fp16/fp32/fp64 of length
|
||||
// Comparison kernel code: compare two buffers of fp16/fp32/fp64/int8 of length
|
||||
// buffer_length where the relative error does not exceed the passed
|
||||
// rel_error_threshold. Write the number of mismatches into out parameter
|
||||
// mismatch_count.
|
||||
@ -46,12 +46,20 @@ static constexpr double kTolerance = 0.1f;
|
||||
//
|
||||
// #include<cuda_fp16.h>
|
||||
// extern "C" { // avoid name mangling
|
||||
// __device__ float canonicalize(float input) {
|
||||
// __device__ float __xla_buffer_comparator_canonicalize(float input) {
|
||||
// // All fp16 infinities are treated as 65505 or -65505, in order to avoid
|
||||
// // differences due to overflows.
|
||||
// return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f));
|
||||
// }
|
||||
//
|
||||
|
||||
// __device__ float __xla_buffer_comparator_extract_int8(int pack) {
|
||||
// // Extract the lower 8 bits from pack and convert it to float
|
||||
// const unsigned int bit_mask = 0xff;
|
||||
// unsigned int bits = pack & bit_mask;
|
||||
// char* int8_ptr = (char*)&bits;
|
||||
// return __int2float_rn(*int8_ptr);
|
||||
// }
|
||||
|
||||
// __global__ void __xla_fp16_comparison(__half* buffer_a, __half* buffer_b,
|
||||
// float rel_error_threshold,
|
||||
// unsigned long long buffer_length,
|
||||
@ -60,15 +68,15 @@ static constexpr double kTolerance = 0.1f;
|
||||
// if (idx >= buffer_length) return;
|
||||
// float elem_a = __half2float(buffer_a[idx]);
|
||||
// float elem_b = __half2float(buffer_b[idx]);
|
||||
// elem_a = canonicalize(elem_a);
|
||||
// elem_b = canonicalize(elem_b);
|
||||
// elem_a = __xla_buffer_comparator_canonicalize(elem_a);
|
||||
// elem_b = __xla_buffer_comparator_canonicalize(elem_b);
|
||||
// if (isnan(elem_a) && isnan(elem_b)) return;
|
||||
// float rel_error = abs(elem_a - elem_b)
|
||||
// / (max(abs(elem_a), abs(elem_b)) + 1);
|
||||
// if (rel_error > rel_error_threshold || isnan(rel_error))
|
||||
// atomicAdd(mismatch_count, 1);
|
||||
// }
|
||||
//
|
||||
|
||||
// __global__ void __xla_fp32_comparison(float* buffer_a, float* buffer_b,
|
||||
// float rel_error_threshold,
|
||||
// unsigned long long buffer_length,
|
||||
@ -85,7 +93,7 @@ static constexpr double kTolerance = 0.1f;
|
||||
// if (rel_error > rel_error_threshold || isnan(rel_error))
|
||||
// atomicAdd(mismatch_count, 1);
|
||||
// }
|
||||
//
|
||||
|
||||
// __global__ void __xla_fp64_comparison(double* buffer_a, double* buffer_b,
|
||||
// float rel_error_threshold,
|
||||
// unsigned long long buffer_length,
|
||||
@ -102,234 +110,440 @@ static constexpr double kTolerance = 0.1f;
|
||||
// if (rel_error > rel_error_threshold || isnan(rel_error))
|
||||
// atomicAdd(mismatch_count, 1);
|
||||
// }
|
||||
|
||||
// __global__ void __xla_int8_comparison(int* buffer_a, int* buffer_b,
|
||||
// float rel_error_threshold,
|
||||
// unsigned long long buffer_length,
|
||||
// int* mismatch_count) {
|
||||
// int idx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
// if (idx >= buffer_length) return;
|
||||
// int pack_a = buffer_a[idx];
|
||||
// int pack_b = buffer_b[idx];
|
||||
// for(int i = 0; i < 4; ++i) {
|
||||
// float elem_a = __xla_buffer_comparator_extract_int8(pack_a);
|
||||
// float elem_b = __xla_buffer_comparator_extract_int8(pack_b);
|
||||
// float rel_error = abs(elem_a - elem_b)
|
||||
// / (max(abs(elem_a), abs(elem_b)) + 1);
|
||||
// if (rel_error > rel_error_threshold || isnan(rel_error))
|
||||
// atomicAdd(mismatch_count, 1);
|
||||
// pack_a >>= 8;
|
||||
// pack_b >>= 8;
|
||||
// }
|
||||
// }
|
||||
// } // end extern declaration.
|
||||
static const char* buffer_compare_ptx = R"(
|
||||
.version 4.2
|
||||
.version 6.4
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
// .globl __xla_fp16_comparison
|
||||
|
||||
.visible .entry __xla_fp16_comparison(
|
||||
.param .u64 __xla_fp16_comparison_param_0,
|
||||
.param .u64 __xla_fp16_comparison_param_1,
|
||||
.param .f32 __xla_fp16_comparison_param_2,
|
||||
.param .u64 __xla_fp16_comparison_param_3,
|
||||
.param .u64 __xla_fp16_comparison_param_4
|
||||
.param .u64 __xla_fp16_comparison_param_0,
|
||||
.param .u64 __xla_fp16_comparison_param_1,
|
||||
.param .f32 __xla_fp16_comparison_param_2,
|
||||
.param .u64 __xla_fp16_comparison_param_3,
|
||||
.param .u64 __xla_fp16_comparison_param_4
|
||||
)
|
||||
{
|
||||
.reg .pred %p<10>;
|
||||
.reg .b16 %rs<3>;
|
||||
.reg .f32 %f<20>;
|
||||
.reg .b32 %r<6>;
|
||||
.reg .b64 %rd<12>;
|
||||
ld.param.u64 %rd8, [__xla_fp16_comparison_param_3];
|
||||
mov.u32 %r1, %tid.x;
|
||||
mov.u32 %r2, %ctaid.x;
|
||||
mov.u32 %r3, %ntid.x;
|
||||
mad.lo.s32 %r4, %r3, %r2, %r1;
|
||||
cvt.s64.s32 %rd4, %r4;
|
||||
setp.ge.u64 %p1, %rd4, %rd8;
|
||||
@%p1 bra LBB7_4;
|
||||
ld.param.u64 %rd5, [__xla_fp16_comparison_param_0];
|
||||
ld.param.u64 %rd7, [__xla_fp16_comparison_param_1];
|
||||
cvta.to.global.u64 %rd2, %rd7;
|
||||
cvta.to.global.u64 %rd3, %rd5;
|
||||
shl.b64 %rd9, %rd4, 1;
|
||||
add.s64 %rd10, %rd3, %rd9;
|
||||
ld.global.u16 %rs1, [%rd10];
|
||||
// begin inline asm
|
||||
{ cvt.f32.f16 %f6, %rs1;}
|
||||
.reg .pred %p<9>;
|
||||
.reg .b16 %rs<3>;
|
||||
.reg .f32 %f<28>;
|
||||
.reg .b32 %r<6>;
|
||||
.reg .b64 %rd<12>;
|
||||
|
||||
// end inline asm
|
||||
add.s64 %rd11, %rd2, %rd9;
|
||||
ld.global.u16 %rs2, [%rd11];
|
||||
// begin inline asm
|
||||
{ cvt.f32.f16 %f7, %rs2;}
|
||||
|
||||
// end inline asm
|
||||
abs.f32 %f8, %f6;
|
||||
setp.gtu.f32 %p2, %f8, 0f7F800000;
|
||||
min.f32 %f9, %f6, 0f477FE100;
|
||||
max.f32 %f10, %f9, 0fC77FE100;
|
||||
selp.f32 %f1, %f6, %f10, %p2;
|
||||
abs.f32 %f11, %f7;
|
||||
setp.gtu.f32 %p3, %f11, 0f7F800000;
|
||||
min.f32 %f12, %f7, 0f477FE100;
|
||||
max.f32 %f13, %f12, 0fC77FE100;
|
||||
selp.f32 %f2, %f7, %f13, %p3;
|
||||
abs.f32 %f3, %f1;
|
||||
setp.gtu.f32 %p4, %f3, 0f7F800000;
|
||||
abs.f32 %f4, %f2;
|
||||
setp.gtu.f32 %p5, %f4, 0f7F800000;
|
||||
and.pred %p6, %p4, %p5;
|
||||
@%p6 bra LBB7_4;
|
||||
ld.param.f32 %f5, [__xla_fp16_comparison_param_2];
|
||||
sub.f32 %f14, %f1, %f2;
|
||||
abs.f32 %f15, %f14;
|
||||
max.f32 %f16, %f3, %f4;
|
||||
add.f32 %f17, %f16, 0f3F800000;
|
||||
div.rn.f32 %f18, %f15, %f17;
|
||||
setp.leu.f32 %p7, %f18, %f5;
|
||||
abs.f32 %f19, %f18;
|
||||
setp.le.f32 %p8, %f19, 0f7F800000;
|
||||
and.pred %p9, %p7, %p8;
|
||||
@%p9 bra LBB7_4;
|
||||
ld.param.u64 %rd6, [__xla_fp16_comparison_param_4];
|
||||
cvta.to.global.u64 %rd1, %rd6;
|
||||
atom.global.add.u32 %r5, [%rd1], 1;
|
||||
LBB7_4:
|
||||
ret;
|
||||
ld.param.u64 %rd1, [__xla_fp16_comparison_param_0];
|
||||
ld.param.u64 %rd2, [__xla_fp16_comparison_param_1];
|
||||
ld.param.f32 %f10, [__xla_fp16_comparison_param_2];
|
||||
ld.param.u64 %rd4, [__xla_fp16_comparison_param_3];
|
||||
ld.param.u64 %rd3, [__xla_fp16_comparison_param_4];
|
||||
mov.u32 %r2, %ntid.x;
|
||||
mov.u32 %r3, %ctaid.x;
|
||||
mov.u32 %r4, %tid.x;
|
||||
mad.lo.s32 %r1, %r2, %r3, %r4;
|
||||
cvt.s64.s32 %rd5, %r1;
|
||||
setp.ge.u64 %p1, %rd5, %rd4;
|
||||
@%p1 bra BB0_9;
|
||||
|
||||
cvta.to.global.u64 %rd6, %rd1;
|
||||
mul.wide.s32 %rd7, %r1, 2;
|
||||
add.s64 %rd8, %rd6, %rd7;
|
||||
ld.global.u16 %rs1, [%rd8];
|
||||
// inline asm
|
||||
{ cvt.f32.f16 %f26, %rs1;}
|
||||
|
||||
// inline asm
|
||||
cvta.to.global.u64 %rd9, %rd2;
|
||||
add.s64 %rd10, %rd9, %rd7;
|
||||
ld.global.u16 %rs2, [%rd10];
|
||||
// inline asm
|
||||
{ cvt.f32.f16 %f27, %rs2;}
|
||||
|
||||
// inline asm
|
||||
abs.f32 %f13, %f26;
|
||||
setp.gtu.f32 %p2, %f13, 0f7F800000;
|
||||
@%p2 bra BB0_3;
|
||||
|
||||
mov.f32 %f14, 0f477FE100;
|
||||
min.f32 %f15, %f26, %f14;
|
||||
mov.f32 %f16, 0fC77FE100;
|
||||
max.f32 %f26, %f16, %f15;
|
||||
|
||||
BB0_3:
|
||||
abs.f32 %f17, %f27;
|
||||
setp.gtu.f32 %p3, %f17, 0f7F800000;
|
||||
@%p3 bra BB0_5;
|
||||
|
||||
mov.f32 %f18, 0f477FE100;
|
||||
min.f32 %f19, %f27, %f18;
|
||||
mov.f32 %f20, 0fC77FE100;
|
||||
max.f32 %f27, %f20, %f19;
|
||||
|
||||
BB0_5:
|
||||
abs.f32 %f7, %f26;
|
||||
setp.gtu.f32 %p4, %f7, 0f7F800000;
|
||||
abs.f32 %f8, %f27;
|
||||
setp.gtu.f32 %p5, %f8, 0f7F800000;
|
||||
and.pred %p6, %p4, %p5;
|
||||
@%p6 bra BB0_9;
|
||||
|
||||
sub.f32 %f21, %f26, %f27;
|
||||
abs.f32 %f22, %f21;
|
||||
max.f32 %f23, %f7, %f8;
|
||||
add.f32 %f24, %f23, 0f3F800000;
|
||||
div.rn.f32 %f9, %f22, %f24;
|
||||
setp.gt.f32 %p7, %f9, %f10;
|
||||
@%p7 bra BB0_8;
|
||||
|
||||
abs.f32 %f25, %f9;
|
||||
setp.le.f32 %p8, %f25, 0f7F800000;
|
||||
@%p8 bra BB0_9;
|
||||
|
||||
BB0_8:
|
||||
cvta.to.global.u64 %rd11, %rd3;
|
||||
atom.global.add.u32 %r5, [%rd11], 1;
|
||||
|
||||
BB0_9:
|
||||
ret;
|
||||
}
|
||||
// .globl __xla_fp32_comparison
|
||||
|
||||
// .globl __xla_fp32_comparison
|
||||
.visible .entry __xla_fp32_comparison(
|
||||
.param .u64 __xla_fp32_comparison_param_0,
|
||||
.param .u64 __xla_fp32_comparison_param_1,
|
||||
.param .f32 __xla_fp32_comparison_param_2,
|
||||
.param .u64 __xla_fp32_comparison_param_3,
|
||||
.param .u64 __xla_fp32_comparison_param_4
|
||||
.param .u64 __xla_fp32_comparison_param_0,
|
||||
.param .u64 __xla_fp32_comparison_param_1,
|
||||
.param .f32 __xla_fp32_comparison_param_2,
|
||||
.param .u64 __xla_fp32_comparison_param_3,
|
||||
.param .u64 __xla_fp32_comparison_param_4
|
||||
)
|
||||
{
|
||||
.reg .pred %p<12>;
|
||||
.reg .f32 %f<12>;
|
||||
.reg .b32 %r<9>;
|
||||
.reg .b64 %rd<12>;
|
||||
.reg .pred %p<10>;
|
||||
.reg .b16 %rs<3>;
|
||||
.reg .f32 %f<13>;
|
||||
.reg .b32 %r<10>;
|
||||
.reg .b64 %rd<12>;
|
||||
|
||||
ld.param.u64 %rd8, [__xla_fp32_comparison_param_3];
|
||||
mov.u32 %r1, %tid.x;
|
||||
mov.u32 %r2, %ctaid.x;
|
||||
mov.u32 %r3, %ntid.x;
|
||||
mad.lo.s32 %r4, %r3, %r2, %r1;
|
||||
cvt.s64.s32 %rd4, %r4;
|
||||
setp.ge.u64 %p1, %rd4, %rd8;
|
||||
@%p1 bra LBB8_6;
|
||||
ld.param.u64 %rd5, [__xla_fp32_comparison_param_0];
|
||||
ld.param.u64 %rd7, [__xla_fp32_comparison_param_1];
|
||||
cvta.to.global.u64 %rd2, %rd7;
|
||||
cvta.to.global.u64 %rd3, %rd5;
|
||||
shl.b64 %rd9, %rd4, 2;
|
||||
add.s64 %rd10, %rd3, %rd9;
|
||||
ld.global.f32 %f1, [%rd10];
|
||||
add.s64 %rd11, %rd2, %rd9;
|
||||
ld.global.f32 %f2, [%rd11];
|
||||
abs.f32 %f3, %f1;
|
||||
setp.gtu.f32 %p2, %f3, 0f7F800000;
|
||||
abs.f32 %f4, %f2;
|
||||
setp.gtu.f32 %p3, %f4, 0f7F800000;
|
||||
and.pred %p4, %p2, %p3;
|
||||
@%p4 bra LBB8_6;
|
||||
setp.neu.f32 %p5, %f3, 0f7F800000;
|
||||
setp.neu.f32 %p6, %f4, 0f7F800000;
|
||||
or.pred %p7, %p5, %p6;
|
||||
@%p7 bra LBB8_4;
|
||||
mov.b32 %r5, %f1;
|
||||
mov.b32 %r6, %f2;
|
||||
xor.b32 %r7, %r6, %r5;
|
||||
setp.gt.s32 %p8, %r7, -1;
|
||||
@%p8 bra LBB8_6;
|
||||
LBB8_4:
|
||||
ld.param.f32 %f5, [__xla_fp32_comparison_param_2];
|
||||
sub.f32 %f6, %f1, %f2;
|
||||
abs.f32 %f7, %f6;
|
||||
max.f32 %f8, %f3, %f4;
|
||||
add.f32 %f9, %f8, 0f3F800000;
|
||||
div.rn.f32 %f10, %f7, %f9;
|
||||
setp.leu.f32 %p9, %f10, %f5;
|
||||
abs.f32 %f11, %f10;
|
||||
setp.le.f32 %p10, %f11, 0f7F800000;
|
||||
and.pred %p11, %p9, %p10;
|
||||
@%p11 bra LBB8_6;
|
||||
ld.param.u64 %rd6, [__xla_fp32_comparison_param_4];
|
||||
cvta.to.global.u64 %rd1, %rd6;
|
||||
atom.global.add.u32 %r8, [%rd1], 1;
|
||||
LBB8_6:
|
||||
ret;
|
||||
|
||||
ld.param.u64 %rd1, [__xla_fp32_comparison_param_0];
|
||||
ld.param.u64 %rd2, [__xla_fp32_comparison_param_1];
|
||||
ld.param.f32 %f6, [__xla_fp32_comparison_param_2];
|
||||
ld.param.u64 %rd4, [__xla_fp32_comparison_param_3];
|
||||
ld.param.u64 %rd3, [__xla_fp32_comparison_param_4];
|
||||
mov.u32 %r2, %ntid.x;
|
||||
mov.u32 %r3, %ctaid.x;
|
||||
mov.u32 %r4, %tid.x;
|
||||
mad.lo.s32 %r1, %r2, %r3, %r4;
|
||||
cvt.s64.s32 %rd5, %r1;
|
||||
setp.ge.u64 %p1, %rd5, %rd4;
|
||||
@%p1 bra BB1_8;
|
||||
|
||||
cvta.to.global.u64 %rd6, %rd1;
|
||||
mul.wide.s32 %rd7, %r1, 4;
|
||||
add.s64 %rd8, %rd6, %rd7;
|
||||
cvta.to.global.u64 %rd9, %rd2;
|
||||
add.s64 %rd10, %rd9, %rd7;
|
||||
ld.global.f32 %f1, [%rd10];
|
||||
ld.global.f32 %f2, [%rd8];
|
||||
abs.f32 %f3, %f2;
|
||||
setp.le.f32 %p2, %f3, 0f7F800000;
|
||||
@%p2 bra BB1_3;
|
||||
|
||||
abs.f32 %f7, %f1;
|
||||
setp.gtu.f32 %p3, %f7, 0f7F800000;
|
||||
@%p3 bra BB1_8;
|
||||
|
||||
BB1_3:
|
||||
setp.neu.f32 %p4, %f3, 0f7F800000;
|
||||
abs.f32 %f4, %f1;
|
||||
setp.neu.f32 %p5, %f4, 0f7F800000;
|
||||
or.pred %p6, %p4, %p5;
|
||||
@%p6 bra BB1_5;
|
||||
|
||||
mov.b32 %r5, %f2;
|
||||
shr.u32 %r6, %r5, 31;
|
||||
cvt.u16.u32 %rs1, %r6;
|
||||
mov.b32 %r7, %f1;
|
||||
shr.u32 %r8, %r7, 31;
|
||||
cvt.u16.u32 %rs2, %r8;
|
||||
setp.eq.s16 %p7, %rs1, %rs2;
|
||||
@%p7 bra BB1_8;
|
||||
|
||||
BB1_5:
|
||||
sub.f32 %f8, %f2, %f1;
|
||||
abs.f32 %f9, %f8;
|
||||
max.f32 %f10, %f3, %f4;
|
||||
add.f32 %f11, %f10, 0f3F800000;
|
||||
div.rn.f32 %f5, %f9, %f11;
|
||||
setp.gt.f32 %p8, %f5, %f6;
|
||||
@%p8 bra BB1_7;
|
||||
|
||||
abs.f32 %f12, %f5;
|
||||
setp.le.f32 %p9, %f12, 0f7F800000;
|
||||
@%p9 bra BB1_8;
|
||||
|
||||
BB1_7:
|
||||
cvta.to.global.u64 %rd11, %rd3;
|
||||
atom.global.add.u32 %r9, [%rd11], 1;
|
||||
|
||||
BB1_8:
|
||||
ret;
|
||||
}
|
||||
// .globl __xla_fp64_comparison
|
||||
|
||||
// .globl __xla_fp64_comparison
|
||||
.visible .entry __xla_fp64_comparison(
|
||||
.param .u64 __xla_fp64_comparison_param_0,
|
||||
.param .u64 __xla_fp64_comparison_param_1,
|
||||
.param .f32 __xla_fp64_comparison_param_2,
|
||||
.param .u64 __xla_fp64_comparison_param_3,
|
||||
.param .u64 __xla_fp64_comparison_param_4
|
||||
.param .u64 __xla_fp64_comparison_param_0,
|
||||
.param .u64 __xla_fp64_comparison_param_1,
|
||||
.param .f32 __xla_fp64_comparison_param_2,
|
||||
.param .u64 __xla_fp64_comparison_param_3,
|
||||
.param .u64 __xla_fp64_comparison_param_4
|
||||
)
|
||||
{
|
||||
.reg .pred %p<16>;
|
||||
.reg .f32 %f<2>;
|
||||
.reg .b32 %r<13>;
|
||||
.reg .f64 %fd<12>;
|
||||
.reg .b64 %rd<12>;
|
||||
.reg .pred %p<11>;
|
||||
.reg .b16 %rs<3>;
|
||||
.reg .f32 %f<2>;
|
||||
.reg .b32 %r<14>;
|
||||
.reg .f64 %fd<13>;
|
||||
.reg .b64 %rd<12>;
|
||||
|
||||
ld.param.u64 %rd8, [__xla_fp64_comparison_param_3];
|
||||
mov.u32 %r2, %tid.x;
|
||||
mov.u32 %r3, %ctaid.x;
|
||||
mov.u32 %r4, %ntid.x;
|
||||
mad.lo.s32 %r5, %r4, %r3, %r2;
|
||||
cvt.s64.s32 %rd4, %r5;
|
||||
setp.ge.u64 %p1, %rd4, %rd8;
|
||||
@%p1 bra LBB9_6;
|
||||
ld.param.u64 %rd5, [__xla_fp64_comparison_param_0];
|
||||
ld.param.u64 %rd7, [__xla_fp64_comparison_param_1];
|
||||
cvta.to.global.u64 %rd2, %rd7;
|
||||
cvta.to.global.u64 %rd3, %rd5;
|
||||
shl.b64 %rd9, %rd4, 3;
|
||||
add.s64 %rd10, %rd3, %rd9;
|
||||
ld.global.f64 %fd1, [%rd10];
|
||||
add.s64 %rd11, %rd2, %rd9;
|
||||
ld.global.f64 %fd2, [%rd11];
|
||||
abs.f64 %fd3, %fd1;
|
||||
setp.gtu.f64 %p2, %fd3, 0d7FF0000000000000;
|
||||
abs.f64 %fd4, %fd2;
|
||||
setp.gtu.f64 %p3, %fd4, 0d7FF0000000000000;
|
||||
and.pred %p4, %p2, %p3;
|
||||
@%p4 bra LBB9_6;
|
||||
{
|
||||
.reg .b32 %temp;
|
||||
mov.b64 {%r6, %temp}, %fd1;
|
||||
}
|
||||
{
|
||||
.reg .b32 %temp;
|
||||
mov.b64 {%temp, %r1}, %fd1;
|
||||
}
|
||||
and.b32 %r7, %r1, 2147483647;
|
||||
setp.ne.s32 %p5, %r7, 2146435072;
|
||||
setp.ne.s32 %p6, %r6, 0;
|
||||
or.pred %p7, %p6, %p5;
|
||||
@%p7 bra LBB9_4;
|
||||
{
|
||||
.reg .b32 %temp;
|
||||
mov.b64 {%r8, %temp}, %fd2;
|
||||
}
|
||||
{
|
||||
.reg .b32 %temp;
|
||||
mov.b64 {%temp, %r9}, %fd2;
|
||||
}
|
||||
and.b32 %r10, %r9, 2147483647;
|
||||
setp.eq.s32 %p8, %r10, 2146435072;
|
||||
setp.eq.s32 %p9, %r8, 0;
|
||||
and.pred %p10, %p8, %p9;
|
||||
xor.b32 %r11, %r9, %r1;
|
||||
setp.gt.s32 %p11, %r11, -1;
|
||||
and.pred %p12, %p11, %p10;
|
||||
@%p12 bra LBB9_6;
|
||||
LBB9_4:
|
||||
ld.param.f32 %f1, [__xla_fp64_comparison_param_2];
|
||||
sub.f64 %fd5, %fd1, %fd2;
|
||||
abs.f64 %fd6, %fd5;
|
||||
max.f64 %fd7, %fd3, %fd4;
|
||||
add.f64 %fd8, %fd7, 0d3FF0000000000000;
|
||||
div.rn.f64 %fd9, %fd6, %fd8;
|
||||
cvt.f64.f32 %fd10, %f1;
|
||||
setp.leu.f64 %p13, %fd9, %fd10;
|
||||
abs.f64 %fd11, %fd9;
|
||||
setp.le.f64 %p14, %fd11, 0d7FF0000000000000;
|
||||
and.pred %p15, %p13, %p14;
|
||||
@%p15 bra LBB9_6;
|
||||
ld.param.u64 %rd6, [__xla_fp64_comparison_param_4];
|
||||
cvta.to.global.u64 %rd1, %rd6;
|
||||
atom.global.add.u32 %r12, [%rd1], 1;
|
||||
LBB9_6:
|
||||
ret;
|
||||
|
||||
ld.param.u64 %rd1, [__xla_fp64_comparison_param_0];
|
||||
ld.param.u64 %rd2, [__xla_fp64_comparison_param_1];
|
||||
ld.param.f32 %f1, [__xla_fp64_comparison_param_2];
|
||||
ld.param.u64 %rd4, [__xla_fp64_comparison_param_3];
|
||||
ld.param.u64 %rd3, [__xla_fp64_comparison_param_4];
|
||||
mov.u32 %r4, %ntid.x;
|
||||
mov.u32 %r5, %ctaid.x;
|
||||
mov.u32 %r6, %tid.x;
|
||||
mad.lo.s32 %r1, %r4, %r5, %r6;
|
||||
cvt.s64.s32 %rd5, %r1;
|
||||
setp.ge.u64 %p1, %rd5, %rd4;
|
||||
@%p1 bra BB2_11;
|
||||
|
||||
cvta.to.global.u64 %rd6, %rd1;
|
||||
mul.wide.s32 %rd7, %r1, 8;
|
||||
add.s64 %rd8, %rd6, %rd7;
|
||||
cvta.to.global.u64 %rd9, %rd2;
|
||||
add.s64 %rd10, %rd9, %rd7;
|
||||
ld.global.f64 %fd1, [%rd10];
|
||||
ld.global.f64 %fd2, [%rd8];
|
||||
abs.f64 %fd3, %fd2;
|
||||
setp.le.f64 %p2, %fd3, 0d7FF0000000000000;
|
||||
@%p2 bra BB2_3;
|
||||
|
||||
abs.f64 %fd5, %fd1;
|
||||
setp.gtu.f64 %p3, %fd5, 0d7FF0000000000000;
|
||||
@%p3 bra BB2_11;
|
||||
|
||||
BB2_3:
|
||||
{
|
||||
.reg .b32 %temp;
|
||||
mov.b64 {%temp, %r2}, %fd2;
|
||||
}
|
||||
and.b32 %r7, %r2, 2147483647;
|
||||
setp.ne.s32 %p4, %r7, 2146435072;
|
||||
@%p4 bra BB2_8;
|
||||
|
||||
{
|
||||
.reg .b32 %temp;
|
||||
mov.b64 {%r8, %temp}, %fd2;
|
||||
}
|
||||
setp.ne.s32 %p5, %r8, 0;
|
||||
@%p5 bra BB2_8;
|
||||
|
||||
{
|
||||
.reg .b32 %temp;
|
||||
mov.b64 {%temp, %r3}, %fd1;
|
||||
}
|
||||
and.b32 %r9, %r3, 2147483647;
|
||||
setp.ne.s32 %p6, %r9, 2146435072;
|
||||
@%p6 bra BB2_8;
|
||||
|
||||
{
|
||||
.reg .b32 %temp;
|
||||
mov.b64 {%r10, %temp}, %fd1;
|
||||
}
|
||||
setp.ne.s32 %p7, %r10, 0;
|
||||
@%p7 bra BB2_8;
|
||||
|
||||
shr.u32 %r11, %r2, 31;
|
||||
cvt.u16.u32 %rs1, %r11;
|
||||
shr.u32 %r12, %r3, 31;
|
||||
cvt.u16.u32 %rs2, %r12;
|
||||
setp.eq.s16 %p8, %rs1, %rs2;
|
||||
@%p8 bra BB2_11;
|
||||
|
||||
BB2_8:
|
||||
sub.f64 %fd6, %fd2, %fd1;
|
||||
abs.f64 %fd7, %fd6;
|
||||
abs.f64 %fd8, %fd1;
|
||||
max.f64 %fd9, %fd3, %fd8;
|
||||
add.f64 %fd10, %fd9, 0d3FF0000000000000;
|
||||
div.rn.f64 %fd4, %fd7, %fd10;
|
||||
cvt.f64.f32 %fd11, %f1;
|
||||
setp.gt.f64 %p9, %fd4, %fd11;
|
||||
@%p9 bra BB2_10;
|
||||
|
||||
abs.f64 %fd12, %fd4;
|
||||
setp.le.f64 %p10, %fd12, 0d7FF0000000000000;
|
||||
@%p10 bra BB2_11;
|
||||
|
||||
BB2_10:
|
||||
cvta.to.global.u64 %rd11, %rd3;
|
||||
atom.global.add.u32 %r13, [%rd11], 1;
|
||||
|
||||
BB2_11:
|
||||
ret;
|
||||
}
|
||||
|
||||
// .globl __xla_int8_comparison
|
||||
.visible .entry __xla_int8_comparison(
|
||||
.param .u64 __xla_int8_comparison_param_0,
|
||||
.param .u64 __xla_int8_comparison_param_1,
|
||||
.param .f32 __xla_int8_comparison_param_2,
|
||||
.param .u64 __xla_int8_comparison_param_3,
|
||||
.param .u64 __xla_int8_comparison_param_4
|
||||
)
|
||||
{
|
||||
.reg .pred %p<10>;
|
||||
.reg .f32 %f<42>;
|
||||
.reg .b32 %r<23>;
|
||||
.reg .b64 %rd<12>;
|
||||
|
||||
|
||||
ld.param.u64 %rd2, [__xla_int8_comparison_param_0];
|
||||
ld.param.u64 %rd3, [__xla_int8_comparison_param_1];
|
||||
ld.param.f32 %f5, [__xla_int8_comparison_param_2];
|
||||
ld.param.u64 %rd4, [__xla_int8_comparison_param_3];
|
||||
ld.param.u64 %rd5, [__xla_int8_comparison_param_4];
|
||||
cvta.to.global.u64 %rd1, %rd5;
|
||||
mov.u32 %r4, %ntid.x;
|
||||
mov.u32 %r5, %ctaid.x;
|
||||
mov.u32 %r6, %tid.x;
|
||||
mad.lo.s32 %r1, %r4, %r5, %r6;
|
||||
cvt.s64.s32 %rd6, %r1;
|
||||
setp.ge.u64 %p1, %rd6, %rd4;
|
||||
@%p1 bra BB3_13;
|
||||
|
||||
cvta.to.global.u64 %rd7, %rd2;
|
||||
mul.wide.s32 %rd8, %r1, 4;
|
||||
add.s64 %rd9, %rd7, %rd8;
|
||||
cvta.to.global.u64 %rd10, %rd3;
|
||||
add.s64 %rd11, %rd10, %rd8;
|
||||
ld.global.u32 %r2, [%rd9];
|
||||
cvt.s32.s8 %r7, %r2;
|
||||
cvt.rn.f32.s32 %f6, %r7;
|
||||
ld.global.u32 %r3, [%rd11];
|
||||
cvt.s32.s8 %r8, %r3;
|
||||
cvt.rn.f32.s32 %f7, %r8;
|
||||
sub.f32 %f8, %f6, %f7;
|
||||
abs.f32 %f9, %f8;
|
||||
abs.f32 %f10, %f6;
|
||||
abs.f32 %f11, %f7;
|
||||
max.f32 %f12, %f10, %f11;
|
||||
add.f32 %f13, %f12, 0f3F800000;
|
||||
div.rn.f32 %f1, %f9, %f13;
|
||||
setp.gt.f32 %p2, %f1, %f5;
|
||||
@%p2 bra BB3_3;
|
||||
|
||||
abs.f32 %f14, %f1;
|
||||
setp.le.f32 %p3, %f14, 0f7F800000;
|
||||
@%p3 bra BB3_4;
|
||||
|
||||
BB3_3:
|
||||
atom.global.add.u32 %r9, [%rd1], 1;
|
||||
|
||||
BB3_4:
|
||||
shr.u32 %r10, %r3, 8;
|
||||
shr.u32 %r11, %r2, 8;
|
||||
cvt.s32.s8 %r12, %r11;
|
||||
cvt.rn.f32.s32 %f15, %r12;
|
||||
cvt.s32.s8 %r13, %r10;
|
||||
cvt.rn.f32.s32 %f16, %r13;
|
||||
sub.f32 %f17, %f15, %f16;
|
||||
abs.f32 %f18, %f17;
|
||||
abs.f32 %f19, %f15;
|
||||
abs.f32 %f20, %f16;
|
||||
max.f32 %f21, %f19, %f20;
|
||||
add.f32 %f22, %f21, 0f3F800000;
|
||||
div.rn.f32 %f2, %f18, %f22;
|
||||
setp.gt.f32 %p4, %f2, %f5;
|
||||
@%p4 bra BB3_6;
|
||||
|
||||
abs.f32 %f23, %f2;
|
||||
setp.le.f32 %p5, %f23, 0f7F800000;
|
||||
@%p5 bra BB3_7;
|
||||
|
||||
BB3_6:
|
||||
atom.global.add.u32 %r14, [%rd1], 1;
|
||||
|
||||
BB3_7:
|
||||
shr.u32 %r15, %r3, 16;
|
||||
shr.u32 %r16, %r2, 16;
|
||||
cvt.s32.s8 %r17, %r16;
|
||||
cvt.rn.f32.s32 %f24, %r17;
|
||||
cvt.s32.s8 %r18, %r15;
|
||||
cvt.rn.f32.s32 %f25, %r18;
|
||||
sub.f32 %f26, %f24, %f25;
|
||||
abs.f32 %f27, %f26;
|
||||
abs.f32 %f28, %f24;
|
||||
abs.f32 %f29, %f25;
|
||||
max.f32 %f30, %f28, %f29;
|
||||
add.f32 %f31, %f30, 0f3F800000;
|
||||
div.rn.f32 %f3, %f27, %f31;
|
||||
setp.gt.f32 %p6, %f3, %f5;
|
||||
@%p6 bra BB3_9;
|
||||
|
||||
abs.f32 %f32, %f3;
|
||||
setp.le.f32 %p7, %f32, 0f7F800000;
|
||||
@%p7 bra BB3_10;
|
||||
|
||||
BB3_9:
|
||||
atom.global.add.u32 %r19, [%rd1], 1;
|
||||
|
||||
BB3_10:
|
||||
shr.s32 %r20, %r2, 24;
|
||||
cvt.rn.f32.s32 %f33, %r20;
|
||||
shr.s32 %r21, %r3, 24;
|
||||
cvt.rn.f32.s32 %f34, %r21;
|
||||
sub.f32 %f35, %f33, %f34;
|
||||
abs.f32 %f36, %f35;
|
||||
abs.f32 %f37, %f33;
|
||||
abs.f32 %f38, %f34;
|
||||
max.f32 %f39, %f37, %f38;
|
||||
add.f32 %f40, %f39, 0f3F800000;
|
||||
div.rn.f32 %f4, %f36, %f40;
|
||||
setp.gt.f32 %p8, %f4, %f5;
|
||||
@%p8 bra BB3_12;
|
||||
|
||||
abs.f32 %f41, %f4;
|
||||
setp.le.f32 %p9, %f41, 0f7F800000;
|
||||
@%p9 bra BB3_13;
|
||||
|
||||
BB3_12:
|
||||
atom.global.add.u32 %r22, [%rd1], 1;
|
||||
|
||||
BB3_13:
|
||||
ret;
|
||||
}
|
||||
)";
|
||||
|
||||
@ -405,11 +619,13 @@ StatusOr<bool> HostCompare(se::Stream* stream, se::DeviceMemoryBase lhs,
|
||||
|
||||
const auto canonicalize = [](ComparisonType a) -> ComparisonType {
|
||||
if (std::is_same<ElementType, Eigen::half>::value && a) {
|
||||
constexpr ComparisonType kMaxFp16Value = 65505.;
|
||||
constexpr ComparisonType kMaxFp16Value =
|
||||
std::is_same<ElementType, Eigen::half>::value ? 65505. : 0;
|
||||
if (std::isnan(a)) {
|
||||
return a;
|
||||
}
|
||||
return std::max(-kMaxFp16Value, std::min(a, kMaxFp16Value));
|
||||
return std::max(static_cast<ComparisonType>(-kMaxFp16Value),
|
||||
static_cast<ComparisonType>(std::min(a, kMaxFp16Value)));
|
||||
}
|
||||
return a;
|
||||
};
|
||||
@ -472,6 +688,9 @@ StatusOr<bool> BufferComparator::CompareEqual(se::Stream* stream,
|
||||
case xla::F64:
|
||||
return CompareEqualParameterized<double, double>(
|
||||
stream, lhs, rhs, shape_, config_, "__xla_fp64_comparison");
|
||||
case xla::S8:
|
||||
return CompareEqualParameterized<int8, float>(
|
||||
stream, lhs, rhs, shape_, config_, "__xla_int8_comparison");
|
||||
default:
|
||||
return Unimplemented("Unimplemented element type");
|
||||
}
|
||||
|
@ -178,6 +178,13 @@ TEST_F(BufferComparatorTest, TestNumbers) {
|
||||
EXPECT_TRUE(CompareEqualFloatBuffers<double>({0.9}, {1}));
|
||||
EXPECT_TRUE(CompareEqualFloatBuffers<double>({9}, {10}));
|
||||
EXPECT_TRUE(CompareEqualFloatBuffers<double>({10}, {9}));
|
||||
|
||||
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({200}, {201}));
|
||||
EXPECT_FALSE(CompareEqualFloatBuffers<int8>({0}, {10}));
|
||||
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({9}, {10}));
|
||||
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({90}, {100}));
|
||||
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({100}, {90}));
|
||||
EXPECT_FALSE(CompareEqualFloatBuffers<int8>({-128}, {127}));
|
||||
}
|
||||
|
||||
TEST_F(BufferComparatorTest, TestMultiple) {
|
||||
@ -231,6 +238,23 @@ TEST_F(BufferComparatorTest, TestMultiple) {
|
||||
rhs[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({20, 30, 40, 50, 60},
|
||||
{21, 31, 41, 51, 61}));
|
||||
std::vector<float> lhs(200);
|
||||
std::vector<float> rhs(200);
|
||||
for (int i = 0; i < 200; i++) {
|
||||
EXPECT_TRUE(CompareEqualFloatBuffers<int8>(lhs, rhs))
|
||||
<< "should be the same at index " << i;
|
||||
lhs[i] = 3;
|
||||
rhs[i] = 5;
|
||||
EXPECT_FALSE(CompareEqualFloatBuffers<int8>(lhs, rhs))
|
||||
<< "should be the different at index " << i;
|
||||
lhs[i] = 0;
|
||||
rhs[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
x
Reference in New Issue
Block a user