Merge pull request #31988 from yongfeng-nv:xla_int8_buffer_comparator

PiperOrigin-RevId: 265890572
TensorFlower Gardener 2019-08-28 05:49:01 -07:00
commit c4fc64c728
2 changed files with 458 additions and 215 deletions


@@ -34,7 +34,7 @@ namespace gpu {
static constexpr double kTolerance = 0.1f;
// Comparison kernel code: compare two buffers of fp16/fp32/fp64 of length
// Comparison kernel code: compare two buffers of fp16/fp32/fp64/int8 of length
// buffer_length where the relative error does not exceed the passed
// rel_error_threshold. Write the number of mismatches into out parameter
// mismatch_count.
@@ -46,12 +46,20 @@ static constexpr double kTolerance = 0.1f;
//
// #include<cuda_fp16.h>
// extern "C" { // avoid name mangling
// __device__ float canonicalize(float input) {
// __device__ float __xla_buffer_comparator_canonicalize(float input) {
// // All fp16 infinities are treated as 65505 or -65505, in order to avoid
// // differences due to overflows.
// return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f));
// }
//
// __device__ float __xla_buffer_comparator_extract_int8(int pack) {
// // Extract the lower 8 bits from pack and convert it to float
// const unsigned int bit_mask = 0xff;
// unsigned int bits = pack & bit_mask;
// char* int8_ptr = (char*)&bits;
// return __int2float_rn(*int8_ptr);
// }
// __global__ void __xla_fp16_comparison(__half* buffer_a, __half* buffer_b,
// float rel_error_threshold,
// unsigned long long buffer_length,
@@ -60,15 +68,15 @@ static constexpr double kTolerance = 0.1f;
// if (idx >= buffer_length) return;
// float elem_a = __half2float(buffer_a[idx]);
// float elem_b = __half2float(buffer_b[idx]);
// elem_a = canonicalize(elem_a);
// elem_b = canonicalize(elem_b);
// elem_a = __xla_buffer_comparator_canonicalize(elem_a);
// elem_b = __xla_buffer_comparator_canonicalize(elem_b);
// if (isnan(elem_a) && isnan(elem_b)) return;
// float rel_error = abs(elem_a - elem_b)
// / (max(abs(elem_a), abs(elem_b)) + 1);
// if (rel_error > rel_error_threshold || isnan(rel_error))
// atomicAdd(mismatch_count, 1);
// }
//
// __global__ void __xla_fp32_comparison(float* buffer_a, float* buffer_b,
// float rel_error_threshold,
// unsigned long long buffer_length,
@@ -85,7 +93,7 @@ static constexpr double kTolerance = 0.1f;
// if (rel_error > rel_error_threshold || isnan(rel_error))
// atomicAdd(mismatch_count, 1);
// }
//
// __global__ void __xla_fp64_comparison(double* buffer_a, double* buffer_b,
// float rel_error_threshold,
// unsigned long long buffer_length,
@@ -102,234 +110,440 @@ static constexpr double kTolerance = 0.1f;
// if (rel_error > rel_error_threshold || isnan(rel_error))
// atomicAdd(mismatch_count, 1);
// }
// __global__ void __xla_int8_comparison(int* buffer_a, int* buffer_b,
// float rel_error_threshold,
// unsigned long long buffer_length,
// int* mismatch_count) {
// int idx = threadIdx.x + blockIdx.x * blockDim.x;
// if (idx >= buffer_length) return;
// int pack_a = buffer_a[idx];
// int pack_b = buffer_b[idx];
// for(int i = 0; i < 4; ++i) {
// float elem_a = __xla_buffer_comparator_extract_int8(pack_a);
// float elem_b = __xla_buffer_comparator_extract_int8(pack_b);
// float rel_error = abs(elem_a - elem_b)
// / (max(abs(elem_a), abs(elem_b)) + 1);
// if (rel_error > rel_error_threshold || isnan(rel_error))
// atomicAdd(mismatch_count, 1);
// pack_a >>= 8;
// pack_b >>= 8;
// }
// }
// } // end extern declaration.
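The new __xla_int8_comparison kernel treats each 32-bit word as four packed int8 lanes and applies the same relative-error rule to every lane. The sketch below mirrors that per-lane logic on the host in plain C++; it is an illustration only, not part of this change, and CountInt8Mismatches is a hypothetical name.

// Host-side illustration: count mismatching int8 lanes in two packed 32-bit
// words, using the same relative-error rule as the comparison kernels above.
#include <algorithm>
#include <cmath>
#include <cstdint>

int CountInt8Mismatches(int32_t pack_a, int32_t pack_b,
                        float rel_error_threshold) {
  // Work on unsigned copies so the byte shifts are well defined regardless of
  // the sign of the packed words.
  uint32_t a = static_cast<uint32_t>(pack_a);
  uint32_t b = static_cast<uint32_t>(pack_b);
  int mismatches = 0;
  for (int i = 0; i < 4; ++i) {
    // Take the low byte of each word and reinterpret it as a signed int8,
    // mirroring __xla_buffer_comparator_extract_int8.
    float elem_a = static_cast<float>(static_cast<int8_t>(a & 0xff));
    float elem_b = static_cast<float>(static_cast<int8_t>(b & 0xff));
    float rel_error = std::abs(elem_a - elem_b) /
                      (std::max(std::abs(elem_a), std::abs(elem_b)) + 1);
    if (rel_error > rel_error_threshold || std::isnan(rel_error)) {
      ++mismatches;
    }
    a >>= 8;
    b >>= 8;
  }
  return mismatches;
}
// Example: CountInt8Mismatches(0x0A090000, 0x0A0A0000, 0.1f) == 0, since the
// only differing lane compares 9 against 10 and 1 / (10 + 1) is below 0.1.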
static const char* buffer_compare_ptx = R"(
.version 4.2
.version 6.4
.target sm_30
.address_size 64
// .globl __xla_fp16_comparison
.visible .entry __xla_fp16_comparison(
.param .u64 __xla_fp16_comparison_param_0,
.param .u64 __xla_fp16_comparison_param_1,
.param .f32 __xla_fp16_comparison_param_2,
.param .u64 __xla_fp16_comparison_param_3,
.param .u64 __xla_fp16_comparison_param_4
.param .u64 __xla_fp16_comparison_param_0,
.param .u64 __xla_fp16_comparison_param_1,
.param .f32 __xla_fp16_comparison_param_2,
.param .u64 __xla_fp16_comparison_param_3,
.param .u64 __xla_fp16_comparison_param_4
)
{
.reg .pred %p<10>;
.reg .b16 %rs<3>;
.reg .f32 %f<20>;
.reg .b32 %r<6>;
.reg .b64 %rd<12>;
ld.param.u64 %rd8, [__xla_fp16_comparison_param_3];
mov.u32 %r1, %tid.x;
mov.u32 %r2, %ctaid.x;
mov.u32 %r3, %ntid.x;
mad.lo.s32 %r4, %r3, %r2, %r1;
cvt.s64.s32 %rd4, %r4;
setp.ge.u64 %p1, %rd4, %rd8;
@%p1 bra LBB7_4;
ld.param.u64 %rd5, [__xla_fp16_comparison_param_0];
ld.param.u64 %rd7, [__xla_fp16_comparison_param_1];
cvta.to.global.u64 %rd2, %rd7;
cvta.to.global.u64 %rd3, %rd5;
shl.b64 %rd9, %rd4, 1;
add.s64 %rd10, %rd3, %rd9;
ld.global.u16 %rs1, [%rd10];
// begin inline asm
{ cvt.f32.f16 %f6, %rs1;}
.reg .pred %p<9>;
.reg .b16 %rs<3>;
.reg .f32 %f<28>;
.reg .b32 %r<6>;
.reg .b64 %rd<12>;
// end inline asm
add.s64 %rd11, %rd2, %rd9;
ld.global.u16 %rs2, [%rd11];
// begin inline asm
{ cvt.f32.f16 %f7, %rs2;}
// end inline asm
abs.f32 %f8, %f6;
setp.gtu.f32 %p2, %f8, 0f7F800000;
min.f32 %f9, %f6, 0f477FE100;
max.f32 %f10, %f9, 0fC77FE100;
selp.f32 %f1, %f6, %f10, %p2;
abs.f32 %f11, %f7;
setp.gtu.f32 %p3, %f11, 0f7F800000;
min.f32 %f12, %f7, 0f477FE100;
max.f32 %f13, %f12, 0fC77FE100;
selp.f32 %f2, %f7, %f13, %p3;
abs.f32 %f3, %f1;
setp.gtu.f32 %p4, %f3, 0f7F800000;
abs.f32 %f4, %f2;
setp.gtu.f32 %p5, %f4, 0f7F800000;
and.pred %p6, %p4, %p5;
@%p6 bra LBB7_4;
ld.param.f32 %f5, [__xla_fp16_comparison_param_2];
sub.f32 %f14, %f1, %f2;
abs.f32 %f15, %f14;
max.f32 %f16, %f3, %f4;
add.f32 %f17, %f16, 0f3F800000;
div.rn.f32 %f18, %f15, %f17;
setp.leu.f32 %p7, %f18, %f5;
abs.f32 %f19, %f18;
setp.le.f32 %p8, %f19, 0f7F800000;
and.pred %p9, %p7, %p8;
@%p9 bra LBB7_4;
ld.param.u64 %rd6, [__xla_fp16_comparison_param_4];
cvta.to.global.u64 %rd1, %rd6;
atom.global.add.u32 %r5, [%rd1], 1;
LBB7_4:
ret;
ld.param.u64 %rd1, [__xla_fp16_comparison_param_0];
ld.param.u64 %rd2, [__xla_fp16_comparison_param_1];
ld.param.f32 %f10, [__xla_fp16_comparison_param_2];
ld.param.u64 %rd4, [__xla_fp16_comparison_param_3];
ld.param.u64 %rd3, [__xla_fp16_comparison_param_4];
mov.u32 %r2, %ntid.x;
mov.u32 %r3, %ctaid.x;
mov.u32 %r4, %tid.x;
mad.lo.s32 %r1, %r2, %r3, %r4;
cvt.s64.s32 %rd5, %r1;
setp.ge.u64 %p1, %rd5, %rd4;
@%p1 bra BB0_9;
cvta.to.global.u64 %rd6, %rd1;
mul.wide.s32 %rd7, %r1, 2;
add.s64 %rd8, %rd6, %rd7;
ld.global.u16 %rs1, [%rd8];
// inline asm
{ cvt.f32.f16 %f26, %rs1;}
// inline asm
cvta.to.global.u64 %rd9, %rd2;
add.s64 %rd10, %rd9, %rd7;
ld.global.u16 %rs2, [%rd10];
// inline asm
{ cvt.f32.f16 %f27, %rs2;}
// inline asm
abs.f32 %f13, %f26;
setp.gtu.f32 %p2, %f13, 0f7F800000;
@%p2 bra BB0_3;
mov.f32 %f14, 0f477FE100;
min.f32 %f15, %f26, %f14;
mov.f32 %f16, 0fC77FE100;
max.f32 %f26, %f16, %f15;
BB0_3:
abs.f32 %f17, %f27;
setp.gtu.f32 %p3, %f17, 0f7F800000;
@%p3 bra BB0_5;
mov.f32 %f18, 0f477FE100;
min.f32 %f19, %f27, %f18;
mov.f32 %f20, 0fC77FE100;
max.f32 %f27, %f20, %f19;
BB0_5:
abs.f32 %f7, %f26;
setp.gtu.f32 %p4, %f7, 0f7F800000;
abs.f32 %f8, %f27;
setp.gtu.f32 %p5, %f8, 0f7F800000;
and.pred %p6, %p4, %p5;
@%p6 bra BB0_9;
sub.f32 %f21, %f26, %f27;
abs.f32 %f22, %f21;
max.f32 %f23, %f7, %f8;
add.f32 %f24, %f23, 0f3F800000;
div.rn.f32 %f9, %f22, %f24;
setp.gt.f32 %p7, %f9, %f10;
@%p7 bra BB0_8;
abs.f32 %f25, %f9;
setp.le.f32 %p8, %f25, 0f7F800000;
@%p8 bra BB0_9;
BB0_8:
cvta.to.global.u64 %rd11, %rd3;
atom.global.add.u32 %r5, [%rd11], 1;
BB0_9:
ret;
}
// .globl __xla_fp32_comparison
.visible .entry __xla_fp32_comparison(
.param .u64 __xla_fp32_comparison_param_0,
.param .u64 __xla_fp32_comparison_param_1,
.param .f32 __xla_fp32_comparison_param_2,
.param .u64 __xla_fp32_comparison_param_3,
.param .u64 __xla_fp32_comparison_param_4
.param .u64 __xla_fp32_comparison_param_0,
.param .u64 __xla_fp32_comparison_param_1,
.param .f32 __xla_fp32_comparison_param_2,
.param .u64 __xla_fp32_comparison_param_3,
.param .u64 __xla_fp32_comparison_param_4
)
{
.reg .pred %p<12>;
.reg .f32 %f<12>;
.reg .b32 %r<9>;
.reg .b64 %rd<12>;
.reg .pred %p<10>;
.reg .b16 %rs<3>;
.reg .f32 %f<13>;
.reg .b32 %r<10>;
.reg .b64 %rd<12>;
ld.param.u64 %rd8, [__xla_fp32_comparison_param_3];
mov.u32 %r1, %tid.x;
mov.u32 %r2, %ctaid.x;
mov.u32 %r3, %ntid.x;
mad.lo.s32 %r4, %r3, %r2, %r1;
cvt.s64.s32 %rd4, %r4;
setp.ge.u64 %p1, %rd4, %rd8;
@%p1 bra LBB8_6;
ld.param.u64 %rd5, [__xla_fp32_comparison_param_0];
ld.param.u64 %rd7, [__xla_fp32_comparison_param_1];
cvta.to.global.u64 %rd2, %rd7;
cvta.to.global.u64 %rd3, %rd5;
shl.b64 %rd9, %rd4, 2;
add.s64 %rd10, %rd3, %rd9;
ld.global.f32 %f1, [%rd10];
add.s64 %rd11, %rd2, %rd9;
ld.global.f32 %f2, [%rd11];
abs.f32 %f3, %f1;
setp.gtu.f32 %p2, %f3, 0f7F800000;
abs.f32 %f4, %f2;
setp.gtu.f32 %p3, %f4, 0f7F800000;
and.pred %p4, %p2, %p3;
@%p4 bra LBB8_6;
setp.neu.f32 %p5, %f3, 0f7F800000;
setp.neu.f32 %p6, %f4, 0f7F800000;
or.pred %p7, %p5, %p6;
@%p7 bra LBB8_4;
mov.b32 %r5, %f1;
mov.b32 %r6, %f2;
xor.b32 %r7, %r6, %r5;
setp.gt.s32 %p8, %r7, -1;
@%p8 bra LBB8_6;
LBB8_4:
ld.param.f32 %f5, [__xla_fp32_comparison_param_2];
sub.f32 %f6, %f1, %f2;
abs.f32 %f7, %f6;
max.f32 %f8, %f3, %f4;
add.f32 %f9, %f8, 0f3F800000;
div.rn.f32 %f10, %f7, %f9;
setp.leu.f32 %p9, %f10, %f5;
abs.f32 %f11, %f10;
setp.le.f32 %p10, %f11, 0f7F800000;
and.pred %p11, %p9, %p10;
@%p11 bra LBB8_6;
ld.param.u64 %rd6, [__xla_fp32_comparison_param_4];
cvta.to.global.u64 %rd1, %rd6;
atom.global.add.u32 %r8, [%rd1], 1;
LBB8_6:
ret;
ld.param.u64 %rd1, [__xla_fp32_comparison_param_0];
ld.param.u64 %rd2, [__xla_fp32_comparison_param_1];
ld.param.f32 %f6, [__xla_fp32_comparison_param_2];
ld.param.u64 %rd4, [__xla_fp32_comparison_param_3];
ld.param.u64 %rd3, [__xla_fp32_comparison_param_4];
mov.u32 %r2, %ntid.x;
mov.u32 %r3, %ctaid.x;
mov.u32 %r4, %tid.x;
mad.lo.s32 %r1, %r2, %r3, %r4;
cvt.s64.s32 %rd5, %r1;
setp.ge.u64 %p1, %rd5, %rd4;
@%p1 bra BB1_8;
cvta.to.global.u64 %rd6, %rd1;
mul.wide.s32 %rd7, %r1, 4;
add.s64 %rd8, %rd6, %rd7;
cvta.to.global.u64 %rd9, %rd2;
add.s64 %rd10, %rd9, %rd7;
ld.global.f32 %f1, [%rd10];
ld.global.f32 %f2, [%rd8];
abs.f32 %f3, %f2;
setp.le.f32 %p2, %f3, 0f7F800000;
@%p2 bra BB1_3;
abs.f32 %f7, %f1;
setp.gtu.f32 %p3, %f7, 0f7F800000;
@%p3 bra BB1_8;
BB1_3:
setp.neu.f32 %p4, %f3, 0f7F800000;
abs.f32 %f4, %f1;
setp.neu.f32 %p5, %f4, 0f7F800000;
or.pred %p6, %p4, %p5;
@%p6 bra BB1_5;
mov.b32 %r5, %f2;
shr.u32 %r6, %r5, 31;
cvt.u16.u32 %rs1, %r6;
mov.b32 %r7, %f1;
shr.u32 %r8, %r7, 31;
cvt.u16.u32 %rs2, %r8;
setp.eq.s16 %p7, %rs1, %rs2;
@%p7 bra BB1_8;
BB1_5:
sub.f32 %f8, %f2, %f1;
abs.f32 %f9, %f8;
max.f32 %f10, %f3, %f4;
add.f32 %f11, %f10, 0f3F800000;
div.rn.f32 %f5, %f9, %f11;
setp.gt.f32 %p8, %f5, %f6;
@%p8 bra BB1_7;
abs.f32 %f12, %f5;
setp.le.f32 %p9, %f12, 0f7F800000;
@%p9 bra BB1_8;
BB1_7:
cvta.to.global.u64 %rd11, %rd3;
atom.global.add.u32 %r9, [%rd11], 1;
BB1_8:
ret;
}
// .globl __xla_fp64_comparison
.visible .entry __xla_fp64_comparison(
.param .u64 __xla_fp64_comparison_param_0,
.param .u64 __xla_fp64_comparison_param_1,
.param .f32 __xla_fp64_comparison_param_2,
.param .u64 __xla_fp64_comparison_param_3,
.param .u64 __xla_fp64_comparison_param_4
.param .u64 __xla_fp64_comparison_param_0,
.param .u64 __xla_fp64_comparison_param_1,
.param .f32 __xla_fp64_comparison_param_2,
.param .u64 __xla_fp64_comparison_param_3,
.param .u64 __xla_fp64_comparison_param_4
)
{
.reg .pred %p<16>;
.reg .f32 %f<2>;
.reg .b32 %r<13>;
.reg .f64 %fd<12>;
.reg .b64 %rd<12>;
.reg .pred %p<11>;
.reg .b16 %rs<3>;
.reg .f32 %f<2>;
.reg .b32 %r<14>;
.reg .f64 %fd<13>;
.reg .b64 %rd<12>;
ld.param.u64 %rd8, [__xla_fp64_comparison_param_3];
mov.u32 %r2, %tid.x;
mov.u32 %r3, %ctaid.x;
mov.u32 %r4, %ntid.x;
mad.lo.s32 %r5, %r4, %r3, %r2;
cvt.s64.s32 %rd4, %r5;
setp.ge.u64 %p1, %rd4, %rd8;
@%p1 bra LBB9_6;
ld.param.u64 %rd5, [__xla_fp64_comparison_param_0];
ld.param.u64 %rd7, [__xla_fp64_comparison_param_1];
cvta.to.global.u64 %rd2, %rd7;
cvta.to.global.u64 %rd3, %rd5;
shl.b64 %rd9, %rd4, 3;
add.s64 %rd10, %rd3, %rd9;
ld.global.f64 %fd1, [%rd10];
add.s64 %rd11, %rd2, %rd9;
ld.global.f64 %fd2, [%rd11];
abs.f64 %fd3, %fd1;
setp.gtu.f64 %p2, %fd3, 0d7FF0000000000000;
abs.f64 %fd4, %fd2;
setp.gtu.f64 %p3, %fd4, 0d7FF0000000000000;
and.pred %p4, %p2, %p3;
@%p4 bra LBB9_6;
{
.reg .b32 %temp;
mov.b64 {%r6, %temp}, %fd1;
}
{
.reg .b32 %temp;
mov.b64 {%temp, %r1}, %fd1;
}
and.b32 %r7, %r1, 2147483647;
setp.ne.s32 %p5, %r7, 2146435072;
setp.ne.s32 %p6, %r6, 0;
or.pred %p7, %p6, %p5;
@%p7 bra LBB9_4;
{
.reg .b32 %temp;
mov.b64 {%r8, %temp}, %fd2;
}
{
.reg .b32 %temp;
mov.b64 {%temp, %r9}, %fd2;
}
and.b32 %r10, %r9, 2147483647;
setp.eq.s32 %p8, %r10, 2146435072;
setp.eq.s32 %p9, %r8, 0;
and.pred %p10, %p8, %p9;
xor.b32 %r11, %r9, %r1;
setp.gt.s32 %p11, %r11, -1;
and.pred %p12, %p11, %p10;
@%p12 bra LBB9_6;
LBB9_4:
ld.param.f32 %f1, [__xla_fp64_comparison_param_2];
sub.f64 %fd5, %fd1, %fd2;
abs.f64 %fd6, %fd5;
max.f64 %fd7, %fd3, %fd4;
add.f64 %fd8, %fd7, 0d3FF0000000000000;
div.rn.f64 %fd9, %fd6, %fd8;
cvt.f64.f32 %fd10, %f1;
setp.leu.f64 %p13, %fd9, %fd10;
abs.f64 %fd11, %fd9;
setp.le.f64 %p14, %fd11, 0d7FF0000000000000;
and.pred %p15, %p13, %p14;
@%p15 bra LBB9_6;
ld.param.u64 %rd6, [__xla_fp64_comparison_param_4];
cvta.to.global.u64 %rd1, %rd6;
atom.global.add.u32 %r12, [%rd1], 1;
LBB9_6:
ret;
ld.param.u64 %rd1, [__xla_fp64_comparison_param_0];
ld.param.u64 %rd2, [__xla_fp64_comparison_param_1];
ld.param.f32 %f1, [__xla_fp64_comparison_param_2];
ld.param.u64 %rd4, [__xla_fp64_comparison_param_3];
ld.param.u64 %rd3, [__xla_fp64_comparison_param_4];
mov.u32 %r4, %ntid.x;
mov.u32 %r5, %ctaid.x;
mov.u32 %r6, %tid.x;
mad.lo.s32 %r1, %r4, %r5, %r6;
cvt.s64.s32 %rd5, %r1;
setp.ge.u64 %p1, %rd5, %rd4;
@%p1 bra BB2_11;
cvta.to.global.u64 %rd6, %rd1;
mul.wide.s32 %rd7, %r1, 8;
add.s64 %rd8, %rd6, %rd7;
cvta.to.global.u64 %rd9, %rd2;
add.s64 %rd10, %rd9, %rd7;
ld.global.f64 %fd1, [%rd10];
ld.global.f64 %fd2, [%rd8];
abs.f64 %fd3, %fd2;
setp.le.f64 %p2, %fd3, 0d7FF0000000000000;
@%p2 bra BB2_3;
abs.f64 %fd5, %fd1;
setp.gtu.f64 %p3, %fd5, 0d7FF0000000000000;
@%p3 bra BB2_11;
BB2_3:
{
.reg .b32 %temp;
mov.b64 {%temp, %r2}, %fd2;
}
and.b32 %r7, %r2, 2147483647;
setp.ne.s32 %p4, %r7, 2146435072;
@%p4 bra BB2_8;
{
.reg .b32 %temp;
mov.b64 {%r8, %temp}, %fd2;
}
setp.ne.s32 %p5, %r8, 0;
@%p5 bra BB2_8;
{
.reg .b32 %temp;
mov.b64 {%temp, %r3}, %fd1;
}
and.b32 %r9, %r3, 2147483647;
setp.ne.s32 %p6, %r9, 2146435072;
@%p6 bra BB2_8;
{
.reg .b32 %temp;
mov.b64 {%r10, %temp}, %fd1;
}
setp.ne.s32 %p7, %r10, 0;
@%p7 bra BB2_8;
shr.u32 %r11, %r2, 31;
cvt.u16.u32 %rs1, %r11;
shr.u32 %r12, %r3, 31;
cvt.u16.u32 %rs2, %r12;
setp.eq.s16 %p8, %rs1, %rs2;
@%p8 bra BB2_11;
BB2_8:
sub.f64 %fd6, %fd2, %fd1;
abs.f64 %fd7, %fd6;
abs.f64 %fd8, %fd1;
max.f64 %fd9, %fd3, %fd8;
add.f64 %fd10, %fd9, 0d3FF0000000000000;
div.rn.f64 %fd4, %fd7, %fd10;
cvt.f64.f32 %fd11, %f1;
setp.gt.f64 %p9, %fd4, %fd11;
@%p9 bra BB2_10;
abs.f64 %fd12, %fd4;
setp.le.f64 %p10, %fd12, 0d7FF0000000000000;
@%p10 bra BB2_11;
BB2_10:
cvta.to.global.u64 %rd11, %rd3;
atom.global.add.u32 %r13, [%rd11], 1;
BB2_11:
ret;
}
// .globl __xla_int8_comparison
.visible .entry __xla_int8_comparison(
.param .u64 __xla_int8_comparison_param_0,
.param .u64 __xla_int8_comparison_param_1,
.param .f32 __xla_int8_comparison_param_2,
.param .u64 __xla_int8_comparison_param_3,
.param .u64 __xla_int8_comparison_param_4
)
{
.reg .pred %p<10>;
.reg .f32 %f<42>;
.reg .b32 %r<23>;
.reg .b64 %rd<12>;
ld.param.u64 %rd2, [__xla_int8_comparison_param_0];
ld.param.u64 %rd3, [__xla_int8_comparison_param_1];
ld.param.f32 %f5, [__xla_int8_comparison_param_2];
ld.param.u64 %rd4, [__xla_int8_comparison_param_3];
ld.param.u64 %rd5, [__xla_int8_comparison_param_4];
cvta.to.global.u64 %rd1, %rd5;
mov.u32 %r4, %ntid.x;
mov.u32 %r5, %ctaid.x;
mov.u32 %r6, %tid.x;
mad.lo.s32 %r1, %r4, %r5, %r6;
cvt.s64.s32 %rd6, %r1;
setp.ge.u64 %p1, %rd6, %rd4;
@%p1 bra BB3_13;
cvta.to.global.u64 %rd7, %rd2;
mul.wide.s32 %rd8, %r1, 4;
add.s64 %rd9, %rd7, %rd8;
cvta.to.global.u64 %rd10, %rd3;
add.s64 %rd11, %rd10, %rd8;
ld.global.u32 %r2, [%rd9];
cvt.s32.s8 %r7, %r2;
cvt.rn.f32.s32 %f6, %r7;
ld.global.u32 %r3, [%rd11];
cvt.s32.s8 %r8, %r3;
cvt.rn.f32.s32 %f7, %r8;
sub.f32 %f8, %f6, %f7;
abs.f32 %f9, %f8;
abs.f32 %f10, %f6;
abs.f32 %f11, %f7;
max.f32 %f12, %f10, %f11;
add.f32 %f13, %f12, 0f3F800000;
div.rn.f32 %f1, %f9, %f13;
setp.gt.f32 %p2, %f1, %f5;
@%p2 bra BB3_3;
abs.f32 %f14, %f1;
setp.le.f32 %p3, %f14, 0f7F800000;
@%p3 bra BB3_4;
BB3_3:
atom.global.add.u32 %r9, [%rd1], 1;
BB3_4:
shr.u32 %r10, %r3, 8;
shr.u32 %r11, %r2, 8;
cvt.s32.s8 %r12, %r11;
cvt.rn.f32.s32 %f15, %r12;
cvt.s32.s8 %r13, %r10;
cvt.rn.f32.s32 %f16, %r13;
sub.f32 %f17, %f15, %f16;
abs.f32 %f18, %f17;
abs.f32 %f19, %f15;
abs.f32 %f20, %f16;
max.f32 %f21, %f19, %f20;
add.f32 %f22, %f21, 0f3F800000;
div.rn.f32 %f2, %f18, %f22;
setp.gt.f32 %p4, %f2, %f5;
@%p4 bra BB3_6;
abs.f32 %f23, %f2;
setp.le.f32 %p5, %f23, 0f7F800000;
@%p5 bra BB3_7;
BB3_6:
atom.global.add.u32 %r14, [%rd1], 1;
BB3_7:
shr.u32 %r15, %r3, 16;
shr.u32 %r16, %r2, 16;
cvt.s32.s8 %r17, %r16;
cvt.rn.f32.s32 %f24, %r17;
cvt.s32.s8 %r18, %r15;
cvt.rn.f32.s32 %f25, %r18;
sub.f32 %f26, %f24, %f25;
abs.f32 %f27, %f26;
abs.f32 %f28, %f24;
abs.f32 %f29, %f25;
max.f32 %f30, %f28, %f29;
add.f32 %f31, %f30, 0f3F800000;
div.rn.f32 %f3, %f27, %f31;
setp.gt.f32 %p6, %f3, %f5;
@%p6 bra BB3_9;
abs.f32 %f32, %f3;
setp.le.f32 %p7, %f32, 0f7F800000;
@%p7 bra BB3_10;
BB3_9:
atom.global.add.u32 %r19, [%rd1], 1;
BB3_10:
shr.s32 %r20, %r2, 24;
cvt.rn.f32.s32 %f33, %r20;
shr.s32 %r21, %r3, 24;
cvt.rn.f32.s32 %f34, %r21;
sub.f32 %f35, %f33, %f34;
abs.f32 %f36, %f35;
abs.f32 %f37, %f33;
abs.f32 %f38, %f34;
max.f32 %f39, %f37, %f38;
add.f32 %f40, %f39, 0f3F800000;
div.rn.f32 %f4, %f36, %f40;
setp.gt.f32 %p8, %f4, %f5;
@%p8 bra BB3_12;
abs.f32 %f41, %f4;
setp.le.f32 %p9, %f41, 0f7F800000;
@%p9 bra BB3_13;
BB3_12:
atom.global.add.u32 %r22, [%rd1], 1;
BB3_13:
ret;
}
)";
@@ -405,11 +619,13 @@ StatusOr<bool> HostCompare(se::Stream* stream, se::DeviceMemoryBase lhs,
const auto canonicalize = [](ComparisonType a) -> ComparisonType {
if (std::is_same<ElementType, Eigen::half>::value && a) {
constexpr ComparisonType kMaxFp16Value = 65505.;
constexpr ComparisonType kMaxFp16Value =
std::is_same<ElementType, Eigen::half>::value ? 65505. : 0;
if (std::isnan(a)) {
return a;
}
return std::max(-kMaxFp16Value, std::min(a, kMaxFp16Value));
return std::max(static_cast<ComparisonType>(-kMaxFp16Value),
static_cast<ComparisonType>(std::min(a, kMaxFp16Value)));
}
return a;
};
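In the host fallback above, values read from an fp16 buffer are canonicalized by clamping to +/-65505, just above the largest finite half value (65504), so an overflow to infinity on one side is compared as a large finite number rather than as an automatic mismatch; NaNs pass through unchanged. A minimal standalone sketch of that clamping, with a hypothetical free-function name and the same constants as the lambda above:

#include <algorithm>
#include <cmath>

// Clamp a value that originated from an fp16 buffer, as the canonicalize
// lambda in HostCompare does: NaN passes through, everything else is limited
// to the range [-65505, 65505].
float CanonicalizeFp16AsFloat(float a) {
  constexpr float kMaxFp16Value = 65505.f;
  if (std::isnan(a)) {
    return a;
  }
  return std::max(-kMaxFp16Value, std::min(a, kMaxFp16Value));
}
// After this clamping, an fp16 overflow (+inf) on one side and the largest
// finite half (65504) on the other differ by only 1, well inside the 0.1
// relative-error tolerance.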
@@ -472,6 +688,9 @@ StatusOr<bool> BufferComparator::CompareEqual(se::Stream* stream,
case xla::F64:
return CompareEqualParameterized<double, double>(
stream, lhs, rhs, shape_, config_, "__xla_fp64_comparison");
case xla::S8:
return CompareEqualParameterized<int8, float>(
stream, lhs, rhs, shape_, config_, "__xla_int8_comparison");
default:
return Unimplemented("Unimplemented element type");
}


@@ -178,6 +178,13 @@ TEST_F(BufferComparatorTest, TestNumbers) {
EXPECT_TRUE(CompareEqualFloatBuffers<double>({0.9}, {1}));
EXPECT_TRUE(CompareEqualFloatBuffers<double>({9}, {10}));
EXPECT_TRUE(CompareEqualFloatBuffers<double>({10}, {9}));
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({200}, {201}));
EXPECT_FALSE(CompareEqualFloatBuffers<int8>({0}, {10}));
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({9}, {10}));
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({90}, {100}));
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({100}, {90}));
EXPECT_FALSE(CompareEqualFloatBuffers<int8>({-128}, {127}));
}
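With kTolerance = 0.1 and the relative-error rule abs(a - b) / (max(abs(a), abs(b)) + 1), the new int8 expectations work out as shown below; note in particular that 200 and 201 wrap to -56 and -55 when stored in an int8 buffer. The check is a standalone illustration of that arithmetic only; Int8WithinTolerance is a hypothetical helper, not the CompareEqualFloatBuffers fixture helper used above.

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Apply the kernel's relative-error rule to a single pair of int8 values.
bool Int8WithinTolerance(int8_t a, int8_t b, float tolerance = 0.1f) {
  float fa = static_cast<float>(a);
  float fb = static_cast<float>(b);
  float rel_error =
      std::abs(fa - fb) / (std::max(std::abs(fa), std::abs(fb)) + 1);
  return rel_error <= tolerance;
}

int main() {
  assert(Int8WithinTolerance(-56, -55));    // 200 vs 201 as int8: 1/57 ~ 0.018
  assert(!Int8WithinTolerance(0, 10));      // 10/11 ~ 0.909
  assert(Int8WithinTolerance(9, 10));       // 1/11 ~ 0.091
  assert(Int8WithinTolerance(90, 100));     // 10/101 ~ 0.099
  assert(!Int8WithinTolerance(-128, 127));  // 255/129 ~ 1.977
  return 0;
}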
TEST_F(BufferComparatorTest, TestMultiple) {
@@ -231,6 +238,23 @@ TEST_F(BufferComparatorTest, TestMultiple) {
rhs[i] = 0;
}
}
{
EXPECT_TRUE(CompareEqualFloatBuffers<int8>({20, 30, 40, 50, 60},
{21, 31, 41, 51, 61}));
std::vector<float> lhs(200);
std::vector<float> rhs(200);
for (int i = 0; i < 200; i++) {
EXPECT_TRUE(CompareEqualFloatBuffers<int8>(lhs, rhs))
<< "should be the same at index " << i;
lhs[i] = 3;
rhs[i] = 5;
EXPECT_FALSE(CompareEqualFloatBuffers<int8>(lhs, rhs))
<< "should be the different at index " << i;
lhs[i] = 0;
rhs[i] = 0;
}
}
}
} // namespace