From 20ec7a8ae2a142955d219bc696015053b442f42e Mon Sep 17 00:00:00 2001
From: Yongfeng Gu
Date: Mon, 26 Aug 2019 14:31:37 -0400
Subject: [PATCH 1/4] Add S8 support.

---
 tensorflow/compiler/xla/service/gpu/buffer_comparator.cc | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
index 30108315e4d..b3a3b1f9c76 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
@@ -405,11 +405,13 @@ StatusOr<bool> HostCompare(se::Stream* stream, se::DeviceMemoryBase lhs,
   const auto canonicalize = [](ComparisonType a) -> ComparisonType {
     if (std::is_same<ElementType, Eigen::half>::value && a) {
-      constexpr ComparisonType kMaxFp16Value = 65505.;
+      constexpr ComparisonType kMaxFp16Value =
+          std::is_same<ElementType, Eigen::half>::value ? 65505. : 0;
       if (std::isnan(a)) {
         return a;
       }
-      return std::max(-kMaxFp16Value, std::min(a, kMaxFp16Value));
+      return std::max(static_cast<ComparisonType>(-kMaxFp16Value),
+                      static_cast<ComparisonType>(std::min(a, kMaxFp16Value)));
     }
     return a;
   };
@@ -472,6 +474,9 @@ StatusOr<bool> BufferComparator::CompareEqual(se::Stream* stream,
     case xla::F64:
       return CompareEqualParameterized<double, double>(
           stream, lhs, rhs, shape_, config_, "__xla_fp64_comparison");
+    case xla::S8:
+      return CompareEqualParameterized<int8, int8>(
+          stream, lhs, rhs, shape_, config_, "__xla_int8_comparison");
     default:
       return Unimplemented("Unimplemented element type");
   }

From 0696cc9a62785e861e2a8c2527035a39e30cb3b0 Mon Sep 17 00:00:00 2001
From: Yongfeng Gu
Date: Tue, 27 Aug 2019 00:01:07 -0400
Subject: [PATCH 2/4] Add ptx for __xla_int8_comparison and int8 tests.

---
 .../xla/service/gpu/buffer_comparator.cc      | 610 ++++++++++++------
 .../xla/service/gpu/buffer_comparator_test.cc |  23 +
 2 files changed, 426 insertions(+), 207 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
index b3a3b1f9c76..6c8a71b884a 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
@@ -102,234 +102,430 @@ static constexpr double kTolerance = 0.1f;
 //   if (rel_error > rel_error_threshold || isnan(rel_error))
 //     atomicAdd(mismatch_count, 1);
 // }
+//
+// __global__ void __xla_int8_comparison(int8* buffer_a, int8* buffer_b,
+//                                       float rel_error_threshold,
+//                                       unsigned long long buffer_length,
+//                                       int* mismatch_count) {
+//   int idx = threadIdx.x + blockIdx.x * blockDim.x;
+//   if (idx >= buffer_length) return;
+//   float elem_a = __int2float_rn(buffer_a[idx]);
+//   float elem_b = __int2float_rn(buffer_b[idx]);
+//   float rel_error = abs(elem_a - elem_b)
+//       / (max(abs(elem_a), abs(elem_b)) + 1);
+//   if (rel_error > rel_error_threshold || isnan(rel_error))
+//     atomicAdd(mismatch_count, 1);
+// }
 // }  // end extern declaration.
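All four kernels above share one mismatch rule; host-side it boils down to the following (a minimal standalone sketch for reference, not part of the patch -- MismatchAtTolerance is an illustrative name):

    #include <algorithm>
    #include <cmath>

    // Relative error with a +1 in the denominator: values near zero do not
    // blow up the ratio, and identical elements give rel_error == 0.
    bool MismatchAtTolerance(float elem_a, float elem_b,
                             float rel_error_threshold) {
      float rel_error = std::abs(elem_a - elem_b) /
                        (std::max(std::abs(elem_a), std::abs(elem_b)) + 1);
      return rel_error > rel_error_threshold || std::isnan(rel_error);
    }

With the default kTolerance of 0.1 this accepts 9 vs. 10 (rel_error of about 0.091) but rejects 0 vs. 10 (rel_error of about 0.91), which is exactly what the tests added later in this series exercise.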
static const char* buffer_compare_ptx = R"( -.version 4.2 +.version 6.4 .target sm_30 .address_size 64 + // .globl __xla_fp16_comparison + .visible .entry __xla_fp16_comparison( - .param .u64 __xla_fp16_comparison_param_0, - .param .u64 __xla_fp16_comparison_param_1, - .param .f32 __xla_fp16_comparison_param_2, - .param .u64 __xla_fp16_comparison_param_3, - .param .u64 __xla_fp16_comparison_param_4 + .param .u64 __xla_fp16_comparison_param_0, + .param .u64 __xla_fp16_comparison_param_1, + .param .f32 __xla_fp16_comparison_param_2, + .param .u64 __xla_fp16_comparison_param_3, + .param .u64 __xla_fp16_comparison_param_4 ) { - .reg .pred %p<10>; - .reg .b16 %rs<3>; - .reg .f32 %f<20>; - .reg .b32 %r<6>; - .reg .b64 %rd<12>; - ld.param.u64 %rd8, [__xla_fp16_comparison_param_3]; - mov.u32 %r1, %tid.x; - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %ntid.x; - mad.lo.s32 %r4, %r3, %r2, %r1; - cvt.s64.s32 %rd4, %r4; - setp.ge.u64 %p1, %rd4, %rd8; - @%p1 bra LBB7_4; - ld.param.u64 %rd5, [__xla_fp16_comparison_param_0]; - ld.param.u64 %rd7, [__xla_fp16_comparison_param_1]; - cvta.to.global.u64 %rd2, %rd7; - cvta.to.global.u64 %rd3, %rd5; - shl.b64 %rd9, %rd4, 1; - add.s64 %rd10, %rd3, %rd9; - ld.global.u16 %rs1, [%rd10]; - // begin inline asm - { cvt.f32.f16 %f6, %rs1;} + .reg .pred %p<9>; + .reg .b16 %rs<3>; + .reg .f32 %f<28>; + .reg .b32 %r<6>; + .reg .b64 %rd<12>; - // end inline asm - add.s64 %rd11, %rd2, %rd9; - ld.global.u16 %rs2, [%rd11]; - // begin inline asm - { cvt.f32.f16 %f7, %rs2;} - // end inline asm - abs.f32 %f8, %f6; - setp.gtu.f32 %p2, %f8, 0f7F800000; - min.f32 %f9, %f6, 0f477FE100; - max.f32 %f10, %f9, 0fC77FE100; - selp.f32 %f1, %f6, %f10, %p2; - abs.f32 %f11, %f7; - setp.gtu.f32 %p3, %f11, 0f7F800000; - min.f32 %f12, %f7, 0f477FE100; - max.f32 %f13, %f12, 0fC77FE100; - selp.f32 %f2, %f7, %f13, %p3; - abs.f32 %f3, %f1; - setp.gtu.f32 %p4, %f3, 0f7F800000; - abs.f32 %f4, %f2; - setp.gtu.f32 %p5, %f4, 0f7F800000; - and.pred %p6, %p4, %p5; - @%p6 bra LBB7_4; - ld.param.f32 %f5, [__xla_fp16_comparison_param_2]; - sub.f32 %f14, %f1, %f2; - abs.f32 %f15, %f14; - max.f32 %f16, %f3, %f4; - add.f32 %f17, %f16, 0f3F800000; - div.rn.f32 %f18, %f15, %f17; - setp.leu.f32 %p7, %f18, %f5; - abs.f32 %f19, %f18; - setp.le.f32 %p8, %f19, 0f7F800000; - and.pred %p9, %p7, %p8; - @%p9 bra LBB7_4; - ld.param.u64 %rd6, [__xla_fp16_comparison_param_4]; - cvta.to.global.u64 %rd1, %rd6; - atom.global.add.u32 %r5, [%rd1], 1; -LBB7_4: - ret; + ld.param.u64 %rd1, [__xla_fp16_comparison_param_0]; + ld.param.u64 %rd2, [__xla_fp16_comparison_param_1]; + ld.param.f32 %f10, [__xla_fp16_comparison_param_2]; + ld.param.u64 %rd4, [__xla_fp16_comparison_param_3]; + ld.param.u64 %rd3, [__xla_fp16_comparison_param_4]; + mov.u32 %r2, %ntid.x; + mov.u32 %r3, %ctaid.x; + mov.u32 %r4, %tid.x; + mad.lo.s32 %r1, %r2, %r3, %r4; + cvt.s64.s32 %rd5, %r1; + setp.ge.u64 %p1, %rd5, %rd4; + @%p1 bra BB0_9; + cvta.to.global.u64 %rd6, %rd1; + mul.wide.s32 %rd7, %r1, 2; + add.s64 %rd8, %rd6, %rd7; + ld.global.u16 %rs1, [%rd8]; + // inline asm + { cvt.f32.f16 %f26, %rs1;} + + // inline asm + cvta.to.global.u64 %rd9, %rd2; + add.s64 %rd10, %rd9, %rd7; + ld.global.u16 %rs2, [%rd10]; + // inline asm + { cvt.f32.f16 %f27, %rs2;} + + // inline asm + abs.f32 %f13, %f26; + setp.gtu.f32 %p2, %f13, 0f7F800000; + @%p2 bra BB0_3; + + mov.f32 %f14, 0f477FE100; + min.f32 %f15, %f26, %f14; + mov.f32 %f16, 0fC77FE100; + max.f32 %f26, %f16, %f15; + +BB0_3: + abs.f32 %f17, %f27; + setp.gtu.f32 %p3, %f17, 0f7F800000; + @%p3 bra BB0_5; + + mov.f32 %f18, 
0f477FE100; + min.f32 %f19, %f27, %f18; + mov.f32 %f20, 0fC77FE100; + max.f32 %f27, %f20, %f19; + +BB0_5: + abs.f32 %f7, %f26; + setp.gtu.f32 %p4, %f7, 0f7F800000; + abs.f32 %f8, %f27; + setp.gtu.f32 %p5, %f8, 0f7F800000; + and.pred %p6, %p4, %p5; + @%p6 bra BB0_9; + + sub.f32 %f21, %f26, %f27; + abs.f32 %f22, %f21; + max.f32 %f23, %f7, %f8; + add.f32 %f24, %f23, 0f3F800000; + div.rn.f32 %f9, %f22, %f24; + setp.gt.f32 %p7, %f9, %f10; + @%p7 bra BB0_8; + + abs.f32 %f25, %f9; + setp.le.f32 %p8, %f25, 0f7F800000; + @%p8 bra BB0_9; + +BB0_8: + cvta.to.global.u64 %rd11, %rd3; + atom.global.add.u32 %r5, [%rd11], 1; + +BB0_9: + ret; } - // .globl __xla_fp32_comparison + + // .globl __xla_fp32_comparison .visible .entry __xla_fp32_comparison( - .param .u64 __xla_fp32_comparison_param_0, - .param .u64 __xla_fp32_comparison_param_1, - .param .f32 __xla_fp32_comparison_param_2, - .param .u64 __xla_fp32_comparison_param_3, - .param .u64 __xla_fp32_comparison_param_4 + .param .u64 __xla_fp32_comparison_param_0, + .param .u64 __xla_fp32_comparison_param_1, + .param .f32 __xla_fp32_comparison_param_2, + .param .u64 __xla_fp32_comparison_param_3, + .param .u64 __xla_fp32_comparison_param_4 ) { - .reg .pred %p<12>; - .reg .f32 %f<12>; - .reg .b32 %r<9>; - .reg .b64 %rd<12>; + .reg .pred %p<10>; + .reg .b16 %rs<3>; + .reg .f32 %f<13>; + .reg .b32 %r<10>; + .reg .b64 %rd<12>; - ld.param.u64 %rd8, [__xla_fp32_comparison_param_3]; - mov.u32 %r1, %tid.x; - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %ntid.x; - mad.lo.s32 %r4, %r3, %r2, %r1; - cvt.s64.s32 %rd4, %r4; - setp.ge.u64 %p1, %rd4, %rd8; - @%p1 bra LBB8_6; - ld.param.u64 %rd5, [__xla_fp32_comparison_param_0]; - ld.param.u64 %rd7, [__xla_fp32_comparison_param_1]; - cvta.to.global.u64 %rd2, %rd7; - cvta.to.global.u64 %rd3, %rd5; - shl.b64 %rd9, %rd4, 2; - add.s64 %rd10, %rd3, %rd9; - ld.global.f32 %f1, [%rd10]; - add.s64 %rd11, %rd2, %rd9; - ld.global.f32 %f2, [%rd11]; - abs.f32 %f3, %f1; - setp.gtu.f32 %p2, %f3, 0f7F800000; - abs.f32 %f4, %f2; - setp.gtu.f32 %p3, %f4, 0f7F800000; - and.pred %p4, %p2, %p3; - @%p4 bra LBB8_6; - setp.neu.f32 %p5, %f3, 0f7F800000; - setp.neu.f32 %p6, %f4, 0f7F800000; - or.pred %p7, %p5, %p6; - @%p7 bra LBB8_4; - mov.b32 %r5, %f1; - mov.b32 %r6, %f2; - xor.b32 %r7, %r6, %r5; - setp.gt.s32 %p8, %r7, -1; - @%p8 bra LBB8_6; -LBB8_4: - ld.param.f32 %f5, [__xla_fp32_comparison_param_2]; - sub.f32 %f6, %f1, %f2; - abs.f32 %f7, %f6; - max.f32 %f8, %f3, %f4; - add.f32 %f9, %f8, 0f3F800000; - div.rn.f32 %f10, %f7, %f9; - setp.leu.f32 %p9, %f10, %f5; - abs.f32 %f11, %f10; - setp.le.f32 %p10, %f11, 0f7F800000; - and.pred %p11, %p9, %p10; - @%p11 bra LBB8_6; - ld.param.u64 %rd6, [__xla_fp32_comparison_param_4]; - cvta.to.global.u64 %rd1, %rd6; - atom.global.add.u32 %r8, [%rd1], 1; -LBB8_6: - ret; + ld.param.u64 %rd1, [__xla_fp32_comparison_param_0]; + ld.param.u64 %rd2, [__xla_fp32_comparison_param_1]; + ld.param.f32 %f6, [__xla_fp32_comparison_param_2]; + ld.param.u64 %rd4, [__xla_fp32_comparison_param_3]; + ld.param.u64 %rd3, [__xla_fp32_comparison_param_4]; + mov.u32 %r2, %ntid.x; + mov.u32 %r3, %ctaid.x; + mov.u32 %r4, %tid.x; + mad.lo.s32 %r1, %r2, %r3, %r4; + cvt.s64.s32 %rd5, %r1; + setp.ge.u64 %p1, %rd5, %rd4; + @%p1 bra BB1_8; + + cvta.to.global.u64 %rd6, %rd1; + mul.wide.s32 %rd7, %r1, 4; + add.s64 %rd8, %rd6, %rd7; + cvta.to.global.u64 %rd9, %rd2; + add.s64 %rd10, %rd9, %rd7; + ld.global.f32 %f1, [%rd10]; + ld.global.f32 %f2, [%rd8]; + abs.f32 %f3, %f2; + setp.le.f32 %p2, %f3, 0f7F800000; + @%p2 bra BB1_3; + + abs.f32 %f7, %f1; + 
setp.gtu.f32 %p3, %f7, 0f7F800000; + @%p3 bra BB1_8; + +BB1_3: + setp.neu.f32 %p4, %f3, 0f7F800000; + abs.f32 %f4, %f1; + setp.neu.f32 %p5, %f4, 0f7F800000; + or.pred %p6, %p4, %p5; + @%p6 bra BB1_5; + + mov.b32 %r5, %f2; + shr.u32 %r6, %r5, 31; + cvt.u16.u32 %rs1, %r6; + mov.b32 %r7, %f1; + shr.u32 %r8, %r7, 31; + cvt.u16.u32 %rs2, %r8; + setp.eq.s16 %p7, %rs1, %rs2; + @%p7 bra BB1_8; + +BB1_5: + sub.f32 %f8, %f2, %f1; + abs.f32 %f9, %f8; + max.f32 %f10, %f3, %f4; + add.f32 %f11, %f10, 0f3F800000; + div.rn.f32 %f5, %f9, %f11; + setp.gt.f32 %p8, %f5, %f6; + @%p8 bra BB1_7; + + abs.f32 %f12, %f5; + setp.le.f32 %p9, %f12, 0f7F800000; + @%p9 bra BB1_8; + +BB1_7: + cvta.to.global.u64 %rd11, %rd3; + atom.global.add.u32 %r9, [%rd11], 1; + +BB1_8: + ret; } - // .globl __xla_fp64_comparison + + // .globl __xla_fp64_comparison .visible .entry __xla_fp64_comparison( - .param .u64 __xla_fp64_comparison_param_0, - .param .u64 __xla_fp64_comparison_param_1, - .param .f32 __xla_fp64_comparison_param_2, - .param .u64 __xla_fp64_comparison_param_3, - .param .u64 __xla_fp64_comparison_param_4 + .param .u64 __xla_fp64_comparison_param_0, + .param .u64 __xla_fp64_comparison_param_1, + .param .f32 __xla_fp64_comparison_param_2, + .param .u64 __xla_fp64_comparison_param_3, + .param .u64 __xla_fp64_comparison_param_4 ) { - .reg .pred %p<16>; - .reg .f32 %f<2>; - .reg .b32 %r<13>; - .reg .f64 %fd<12>; - .reg .b64 %rd<12>; + .reg .pred %p<11>; + .reg .b16 %rs<3>; + .reg .f32 %f<2>; + .reg .b32 %r<14>; + .reg .f64 %fd<13>; + .reg .b64 %rd<12>; - ld.param.u64 %rd8, [__xla_fp64_comparison_param_3]; - mov.u32 %r2, %tid.x; - mov.u32 %r3, %ctaid.x; - mov.u32 %r4, %ntid.x; - mad.lo.s32 %r5, %r4, %r3, %r2; - cvt.s64.s32 %rd4, %r5; - setp.ge.u64 %p1, %rd4, %rd8; - @%p1 bra LBB9_6; - ld.param.u64 %rd5, [__xla_fp64_comparison_param_0]; - ld.param.u64 %rd7, [__xla_fp64_comparison_param_1]; - cvta.to.global.u64 %rd2, %rd7; - cvta.to.global.u64 %rd3, %rd5; - shl.b64 %rd9, %rd4, 3; - add.s64 %rd10, %rd3, %rd9; - ld.global.f64 %fd1, [%rd10]; - add.s64 %rd11, %rd2, %rd9; - ld.global.f64 %fd2, [%rd11]; - abs.f64 %fd3, %fd1; - setp.gtu.f64 %p2, %fd3, 0d7FF0000000000000; - abs.f64 %fd4, %fd2; - setp.gtu.f64 %p3, %fd4, 0d7FF0000000000000; - and.pred %p4, %p2, %p3; - @%p4 bra LBB9_6; - { - .reg .b32 %temp; - mov.b64 {%r6, %temp}, %fd1; - } - { - .reg .b32 %temp; - mov.b64 {%temp, %r1}, %fd1; - } - and.b32 %r7, %r1, 2147483647; - setp.ne.s32 %p5, %r7, 2146435072; - setp.ne.s32 %p6, %r6, 0; - or.pred %p7, %p6, %p5; - @%p7 bra LBB9_4; - { - .reg .b32 %temp; - mov.b64 {%r8, %temp}, %fd2; - } - { - .reg .b32 %temp; - mov.b64 {%temp, %r9}, %fd2; - } - and.b32 %r10, %r9, 2147483647; - setp.eq.s32 %p8, %r10, 2146435072; - setp.eq.s32 %p9, %r8, 0; - and.pred %p10, %p8, %p9; - xor.b32 %r11, %r9, %r1; - setp.gt.s32 %p11, %r11, -1; - and.pred %p12, %p11, %p10; - @%p12 bra LBB9_6; -LBB9_4: - ld.param.f32 %f1, [__xla_fp64_comparison_param_2]; - sub.f64 %fd5, %fd1, %fd2; - abs.f64 %fd6, %fd5; - max.f64 %fd7, %fd3, %fd4; - add.f64 %fd8, %fd7, 0d3FF0000000000000; - div.rn.f64 %fd9, %fd6, %fd8; - cvt.f64.f32 %fd10, %f1; - setp.leu.f64 %p13, %fd9, %fd10; - abs.f64 %fd11, %fd9; - setp.le.f64 %p14, %fd11, 0d7FF0000000000000; - and.pred %p15, %p13, %p14; - @%p15 bra LBB9_6; - ld.param.u64 %rd6, [__xla_fp64_comparison_param_4]; - cvta.to.global.u64 %rd1, %rd6; - atom.global.add.u32 %r12, [%rd1], 1; -LBB9_6: - ret; + + ld.param.u64 %rd1, [__xla_fp64_comparison_param_0]; + ld.param.u64 %rd2, [__xla_fp64_comparison_param_1]; + ld.param.f32 %f1, 
[__xla_fp64_comparison_param_2]; + ld.param.u64 %rd4, [__xla_fp64_comparison_param_3]; + ld.param.u64 %rd3, [__xla_fp64_comparison_param_4]; + mov.u32 %r4, %ntid.x; + mov.u32 %r5, %ctaid.x; + mov.u32 %r6, %tid.x; + mad.lo.s32 %r1, %r4, %r5, %r6; + cvt.s64.s32 %rd5, %r1; + setp.ge.u64 %p1, %rd5, %rd4; + @%p1 bra BB2_11; + + cvta.to.global.u64 %rd6, %rd1; + mul.wide.s32 %rd7, %r1, 8; + add.s64 %rd8, %rd6, %rd7; + cvta.to.global.u64 %rd9, %rd2; + add.s64 %rd10, %rd9, %rd7; + ld.global.f64 %fd1, [%rd10]; + ld.global.f64 %fd2, [%rd8]; + abs.f64 %fd3, %fd2; + setp.le.f64 %p2, %fd3, 0d7FF0000000000000; + @%p2 bra BB2_3; + + abs.f64 %fd5, %fd1; + setp.gtu.f64 %p3, %fd5, 0d7FF0000000000000; + @%p3 bra BB2_11; + +BB2_3: + { + .reg .b32 %temp; + mov.b64 {%temp, %r2}, %fd2; + } + and.b32 %r7, %r2, 2147483647; + setp.ne.s32 %p4, %r7, 2146435072; + @%p4 bra BB2_8; + + { + .reg .b32 %temp; + mov.b64 {%r8, %temp}, %fd2; + } + setp.ne.s32 %p5, %r8, 0; + @%p5 bra BB2_8; + + { + .reg .b32 %temp; + mov.b64 {%temp, %r3}, %fd1; + } + and.b32 %r9, %r3, 2147483647; + setp.ne.s32 %p6, %r9, 2146435072; + @%p6 bra BB2_8; + + { + .reg .b32 %temp; + mov.b64 {%r10, %temp}, %fd1; + } + setp.ne.s32 %p7, %r10, 0; + @%p7 bra BB2_8; + + shr.u32 %r11, %r2, 31; + cvt.u16.u32 %rs1, %r11; + shr.u32 %r12, %r3, 31; + cvt.u16.u32 %rs2, %r12; + setp.eq.s16 %p8, %rs1, %rs2; + @%p8 bra BB2_11; + +BB2_8: + sub.f64 %fd6, %fd2, %fd1; + abs.f64 %fd7, %fd6; + abs.f64 %fd8, %fd1; + max.f64 %fd9, %fd3, %fd8; + add.f64 %fd10, %fd9, 0d3FF0000000000000; + div.rn.f64 %fd4, %fd7, %fd10; + cvt.f64.f32 %fd11, %f1; + setp.gt.f64 %p9, %fd4, %fd11; + @%p9 bra BB2_10; + + abs.f64 %fd12, %fd4; + setp.le.f64 %p10, %fd12, 0d7FF0000000000000; + @%p10 bra BB2_11; + +BB2_10: + cvta.to.global.u64 %rd11, %rd3; + atom.global.add.u32 %r13, [%rd11], 1; + +BB2_11: + ret; +} + + // .globl __xla_int8_comparison +.visible .entry __xla_int8_comparison( + .param .u64 __xla_int8_comparison_param_0, + .param .u64 __xla_int8_comparison_param_1, + .param .f32 __xla_int8_comparison_param_2, + .param .u64 __xla_int8_comparison_param_3, + .param .u64 __xla_int8_comparison_param_4 +) +{ + .reg .pred %p<10>; + .reg .f32 %f<42>; + .reg .b32 %r<19>; + .reg .b64 %rd<12>; + + + ld.param.u64 %rd2, [__xla_int8_comparison_param_0]; + ld.param.u64 %rd3, [__xla_int8_comparison_param_1]; + ld.param.f32 %f5, [__xla_int8_comparison_param_2]; + ld.param.u64 %rd4, [__xla_int8_comparison_param_3]; + ld.param.u64 %rd5, [__xla_int8_comparison_param_4]; + cvta.to.global.u64 %rd1, %rd5; + mov.u32 %r4, %ntid.x; + mov.u32 %r5, %ctaid.x; + mov.u32 %r6, %tid.x; + mad.lo.s32 %r1, %r4, %r5, %r6; + cvt.s64.s32 %rd6, %r1; + setp.ge.u64 %p1, %rd6, %rd4; + @%p1 bra BB3_13; + + cvta.to.global.u64 %rd7, %rd2; + mul.wide.s32 %rd8, %r1, 4; + add.s64 %rd9, %rd7, %rd8; + cvta.to.global.u64 %rd10, %rd3; + add.s64 %rd11, %rd10, %rd8; + ld.global.u32 %r2, [%rd9]; + and.b32 %r7, %r2, 255; + cvt.rn.f32.s32 %f6, %r7; + ld.global.u32 %r3, [%rd11]; + and.b32 %r8, %r3, 255; + cvt.rn.f32.s32 %f7, %r8; + sub.f32 %f8, %f6, %f7; + abs.f32 %f9, %f8; + abs.f32 %f10, %f6; + abs.f32 %f11, %f7; + max.f32 %f12, %f10, %f11; + add.f32 %f13, %f12, 0f3F800000; + div.rn.f32 %f1, %f9, %f13; + setp.gt.f32 %p2, %f1, %f5; + @%p2 bra BB3_3; + + abs.f32 %f14, %f1; + setp.le.f32 %p3, %f14, 0f7F800000; + @%p3 bra BB3_4; + +BB3_3: + atom.global.add.u32 %r9, [%rd1], 1; + +BB3_4: + bfe.u32 %r10, %r2, 8, 8; + cvt.rn.f32.s32 %f15, %r10; + bfe.u32 %r11, %r3, 8, 8; + cvt.rn.f32.s32 %f16, %r11; + sub.f32 %f17, %f15, %f16; + abs.f32 %f18, %f17; + 
%f17;
    abs.f32 %f19, %f15;
    abs.f32 %f20, %f16;
    max.f32 %f21, %f19, %f20;
    add.f32 %f22, %f21, 0f3F800000;
    div.rn.f32 %f2, %f18, %f22;
    setp.gt.f32 %p4, %f2, %f5;
    @%p4 bra BB3_6;

    abs.f32 %f23, %f2;
    setp.le.f32 %p5, %f23, 0f7F800000;
    @%p5 bra BB3_7;

BB3_6:
    atom.global.add.u32 %r12, [%rd1], 1;

BB3_7:
    bfe.u32 %r13, %r2, 16, 8;
    cvt.rn.f32.s32 %f24, %r13;
    bfe.u32 %r14, %r3, 16, 8;
    cvt.rn.f32.s32 %f25, %r14;
    sub.f32 %f26, %f24, %f25;
    abs.f32 %f27, %f26;
    abs.f32 %f28, %f24;
    abs.f32 %f29, %f25;
    max.f32 %f30, %f28, %f29;
    add.f32 %f31, %f30, 0f3F800000;
    div.rn.f32 %f3, %f27, %f31;
    setp.gt.f32 %p6, %f3, %f5;
    @%p6 bra BB3_9;

    abs.f32 %f32, %f3;
    setp.le.f32 %p7, %f32, 0f7F800000;
    @%p7 bra BB3_10;

BB3_9:
    atom.global.add.u32 %r15, [%rd1], 1;

BB3_10:
    shr.u32 %r16, %r3, 24;
    shr.u32 %r17, %r2, 24;
    cvt.rn.f32.s32 %f33, %r17;
    cvt.rn.f32.s32 %f34, %r16;
    sub.f32 %f35, %f33, %f34;
    abs.f32 %f36, %f35;
    abs.f32 %f37, %f33;
    abs.f32 %f38, %f34;
    max.f32 %f39, %f37, %f38;
    add.f32 %f40, %f39, 0f3F800000;
    div.rn.f32 %f4, %f36, %f40;
    setp.gt.f32 %p8, %f4, %f5;
    @%p8 bra BB3_12;

    abs.f32 %f41, %f4;
    setp.le.f32 %p9, %f41, 0f7F800000;
    @%p9 bra BB3_13;

BB3_12:
    atom.global.add.u32 %r18, [%rd1], 1;

BB3_13:
    ret;
}
)";
@@ -475,7 +671,7 @@ StatusOr<bool> BufferComparator::CompareEqual(se::Stream* stream,
       return CompareEqualParameterized<double, double>(
           stream, lhs, rhs, shape_, config_, "__xla_fp64_comparison");
     case xla::S8:
-      return CompareEqualParameterized<int8, int8>(
+      return CompareEqualParameterized<int8, float>(
           stream, lhs, rhs, shape_, config_, "__xla_int8_comparison");
     default:
       return Unimplemented("Unimplemented element type");
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc
index 139e4204304..4ba7296304b 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc
@@ -178,6 +178,12 @@ TEST_F(BufferComparatorTest, TestNumbers) {
   EXPECT_TRUE(CompareEqualFloatBuffers<double>({0.9}, {1}));
   EXPECT_TRUE(CompareEqualFloatBuffers<double>({9}, {10}));
   EXPECT_TRUE(CompareEqualFloatBuffers<double>({10}, {9}));
+
+  EXPECT_TRUE(CompareEqualFloatBuffers<int8>({200}, {201}));
+  EXPECT_FALSE(CompareEqualFloatBuffers<int8>({0}, {10}));
+  EXPECT_TRUE(CompareEqualFloatBuffers<int8>({9}, {10}));
+  EXPECT_TRUE(CompareEqualFloatBuffers<int8>({90}, {100}));
+  EXPECT_TRUE(CompareEqualFloatBuffers<int8>({100}, {90}));
 }

 TEST_F(BufferComparatorTest, TestMultiple) {
@@ -231,6 +237,23 @@ TEST_F(BufferComparatorTest, TestMultiple) {
       rhs[i] = 0;
     }
   }
+
+  {
+    EXPECT_TRUE(CompareEqualFloatBuffers<int8>(
+        {20, 30, 40, 50, 60}, {21, 31, 41, 51, 61}));
+    std::vector<int8> lhs(200);
+    std::vector<int8> rhs(200);
+    for (int i = 0; i < 200; i++) {
+      EXPECT_TRUE(CompareEqualFloatBuffers<int8>(lhs, rhs))
+          << "should be the same at index " << i;
+      lhs[i] = 3;
+      rhs[i] = 5;
+      EXPECT_FALSE(CompareEqualFloatBuffers<int8>(lhs, rhs))
+          << "should be the different at index " << i;
+      lhs[i] = 0;
+      rhs[i] = 0;
+    }
+  }
 }

 }  // namespace

From 25967fc3aa3aa98c5b80b5f9989487cb694e105d Mon Sep 17 00:00:00 2001
From: Yongfeng Gu
Date: Tue, 27 Aug 2019 16:35:17 -0400
Subject: [PATCH 3/4] Update int8 comparison kernel and test for negative
 values.
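The patch 2 kernel extracted each packed byte with a zero-extending mask (and.b32 %rN, %rM, 255), so a negative int8 such as -1 (0xFF) was read as 255.0f. This patch switches the first three byte lanes to cvt.s32.s8 and the top lane to an arithmetic shift, both of which sign-extend. A host-side sketch of the difference (illustrative only; the function names are ours, not part of the patch):

    #include <cstdint>
    #include <cstdio>

    // Patch 2 behavior: plain mask, zero-extends. 0xFF is read as 255.0f.
    float extract_zero_extended(int pack) {
      return static_cast<float>(pack & 0xff);
    }

    // Patch 3 behavior: reinterpret the low byte as signed. 0xFF -> -1.0f.
    float extract_sign_extended(int pack) {
      return static_cast<float>(static_cast<int8_t>(pack & 0xff));
    }

    int main() {
      int pack = 0xff;  // a packed int8 element with value -1
      std::printf("%.1f vs %.1f\n", extract_zero_extended(pack),
                  extract_sign_extended(pack));  // prints 255.0 vs -1.0
    }

Without sign extension, -128 and 127 are read as 128 vs. 127 and wrongly compare equal; the new {-128}, {127} test below only fails (as it should) once the kernel sign-extends.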
---
 .../xla/service/gpu/buffer_comparator.cc      | 84 +++++++++++--------
 .../xla/service/gpu/buffer_comparator_test.cc |  1 +
 2 files changed, 52 insertions(+), 33 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
index 6c8a71b884a..e9a2f86aabc 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
@@ -34,7 +34,7 @@ namespace gpu {

 static constexpr double kTolerance = 0.1f;

-// Comparison kernel code: compare two buffers of fp16/fp32/fp64 of length
+// Comparison kernel code: compare two buffers of fp16/fp32/fp64/int8 of length
 // buffer_length where the relative error does not exceed the passed
 // rel_error_threshold. Write the number of mismatches into out parameter
 // mismatch_count.
@@ -46,12 +46,20 @@ static constexpr double kTolerance = 0.1f;
 //
 // #include <cuda_fp16.h>
 // extern "C" {  // avoid name mangling
-// __device__ float canonicalize(float input) {
+// __device__ float canonicalize_fp16(float input) {
 //   // All fp16 infinities are treated as 65505 or -65505, in order to avoid
 //   // differences due to overflows.
 //   return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f));
 // }
-//
+
+// __device__ float extract_int8(int pack) {
+//   // Extract the lower 8 bits from pack and convert it to float
+//   const unsigned int bit_mask = 0xff;
+//   unsigned int bits = pack & bit_mask;
+//   char* int8_ptr = (char*)&bits;
+//   return __int2float_rn(*int8_ptr);
+// }
+
 // __global__ void __xla_fp16_comparison(__half* buffer_a, __half* buffer_b,
 //                                       float rel_error_threshold,
 //                                       unsigned long long buffer_length,
@@ -60,15 +68,15 @@ static constexpr double kTolerance = 0.1f;
 //   int idx = threadIdx.x + blockIdx.x * blockDim.x;
 //   if (idx >= buffer_length) return;
 //   float elem_a = __half2float(buffer_a[idx]);
 //   float elem_b = __half2float(buffer_b[idx]);
-//   elem_a = canonicalize(elem_a);
-//   elem_b = canonicalize(elem_b);
+//   elem_a = canonicalize_fp16(elem_a);
+//   elem_b = canonicalize_fp16(elem_b);
 //   if (isnan(elem_a) && isnan(elem_b)) return;
 //   float rel_error = abs(elem_a - elem_b)
 //       / (max(abs(elem_a), abs(elem_b)) + 1);
 //   if (rel_error > rel_error_threshold || isnan(rel_error))
 //     atomicAdd(mismatch_count, 1);
 // }
-//
+
 // __global__ void __xla_fp32_comparison(float* buffer_a, float* buffer_b,
 //                                       float rel_error_threshold,
 //                                       unsigned long long buffer_length,
@@ -85,7 +93,7 @@ static constexpr double kTolerance = 0.1f;
 //   if (rel_error > rel_error_threshold || isnan(rel_error))
 //     atomicAdd(mismatch_count, 1);
 // }
-//
+
 // __global__ void __xla_fp64_comparison(double* buffer_a, double* buffer_b,
 //                                       float rel_error_threshold,
 //                                       unsigned long long buffer_length,
@@ -102,19 +110,25 @@ static constexpr double kTolerance = 0.1f;
 //   if (rel_error > rel_error_threshold || isnan(rel_error))
 //     atomicAdd(mismatch_count, 1);
 // }
-//
-// __global__ void __xla_int8_comparison(int8* buffer_a, int8* buffer_b,
+
+// __global__ void __xla_int8_comparison(int* buffer_a, int* buffer_b,
 //                                       float rel_error_threshold,
 //                                       unsigned long long buffer_length,
 //                                       int* mismatch_count) {
 //   int idx = threadIdx.x + blockIdx.x * blockDim.x;
 //   if (idx >= buffer_length) return;
-//   float elem_a = __int2float_rn(buffer_a[idx]);
-//   float elem_b = __int2float_rn(buffer_b[idx]);
-//   float rel_error = abs(elem_a - elem_b)
-//       / (max(abs(elem_a), abs(elem_b)) + 1);
-//   if (rel_error > rel_error_threshold || isnan(rel_error))
-//     atomicAdd(mismatch_count, 1);
+//   int pack_a = buffer_a[idx];
+//   int pack_b = buffer_b[idx];
+//   for(int i = 0; i < 4; ++i) {
+//     float elem_a = extract_int8(pack_a);
+//     float elem_b = extract_int8(pack_b);
+//     float rel_error = abs(elem_a - elem_b)
+//         / (max(abs(elem_a), abs(elem_b)) + 1);
+//     if (rel_error > rel_error_threshold || isnan(rel_error))
+//       atomicAdd(mismatch_count, 1);
+//     pack_a >>= 8;
+//     pack_b >>= 8;
+//   }
 // }
 // }  // end extern declaration.
 static const char* buffer_compare_ptx = R"(
@@ -412,7 +426,7 @@ BB2_11:
 {
     .reg .pred %p<10>;
     .reg .f32 %f<42>;
-    .reg .b32 %r<19>;
+    .reg .b32 %r<23>;
     .reg .b64 %rd<12>;
@@ -436,10 +450,10 @@ BB2_11:
     cvta.to.global.u64 %rd10, %rd3;
     add.s64 %rd11, %rd10, %rd8;
     ld.global.u32 %r2, [%rd9];
-    and.b32 %r7, %r2, 255;
+    cvt.s32.s8 %r7, %r2;
     cvt.rn.f32.s32 %f6, %r7;
     ld.global.u32 %r3, [%rd11];
-    and.b32 %r8, %r3, 255;
+    cvt.s32.s8 %r8, %r3;
     cvt.rn.f32.s32 %f7, %r8;
     sub.f32 %f8, %f6, %f7;
     abs.f32 %f9, %f8;
@@ -459,10 +473,12 @@ BB3_3:
     atom.global.add.u32 %r9, [%rd1], 1;

 BB3_4:
-    bfe.u32 %r10, %r2, 8, 8;
-    cvt.rn.f32.s32 %f15, %r10;
-    bfe.u32 %r11, %r3, 8, 8;
-    cvt.rn.f32.s32 %f16, %r11;
+    shr.u32 %r10, %r3, 8;
+    shr.u32 %r11, %r2, 8;
+    cvt.s32.s8 %r12, %r11;
+    cvt.rn.f32.s32 %f15, %r12;
+    cvt.s32.s8 %r13, %r10;
+    cvt.rn.f32.s32 %f16, %r13;
     sub.f32 %f17, %f15, %f16;
     abs.f32 %f18, %f17;
     abs.f32 %f19, %f15;
@@ -478,13 +494,15 @@ BB3_4:
     @%p5 bra BB3_7;

 BB3_6:
-    atom.global.add.u32 %r12, [%rd1], 1;
+    atom.global.add.u32 %r14, [%rd1], 1;

 BB3_7:
-    bfe.u32 %r13, %r2, 16, 8;
-    cvt.rn.f32.s32 %f24, %r13;
-    bfe.u32 %r14, %r3, 16, 8;
-    cvt.rn.f32.s32 %f25, %r14;
+    shr.u32 %r15, %r3, 16;
+    shr.u32 %r16, %r2, 16;
+    cvt.s32.s8 %r17, %r16;
+    cvt.rn.f32.s32 %f24, %r17;
+    cvt.s32.s8 %r18, %r15;
+    cvt.rn.f32.s32 %f25, %r18;
     sub.f32 %f26, %f24, %f25;
     abs.f32 %f27, %f26;
     abs.f32 %f28, %f24;
@@ -500,13 +518,13 @@ BB3_7:
     @%p7 bra BB3_10;

 BB3_9:
-    atom.global.add.u32 %r15, [%rd1], 1;
+    atom.global.add.u32 %r19, [%rd1], 1;

 BB3_10:
-    shr.u32 %r16, %r3, 24;
-    shr.u32 %r17, %r2, 24;
-    cvt.rn.f32.s32 %f33, %r17;
-    cvt.rn.f32.s32 %f34, %r16;
+    shr.s32 %r20, %r2, 24;
+    cvt.rn.f32.s32 %f33, %r20;
+    shr.s32 %r21, %r3, 24;
+    cvt.rn.f32.s32 %f34, %r21;
     sub.f32 %f35, %f33, %f34;
     abs.f32 %f36, %f35;
     abs.f32 %f37, %f33;
@@ -522,7 +540,7 @@ BB3_10:
     @%p9 bra BB3_13;

 BB3_12:
-    atom.global.add.u32 %r18, [%rd1], 1;
+    atom.global.add.u32 %r22, [%rd1], 1;

 BB3_13:
     ret;
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc
index 4ba7296304b..77753c1d093 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc
@@ -184,6 +184,7 @@ TEST_F(BufferComparatorTest, TestNumbers) {
   EXPECT_TRUE(CompareEqualFloatBuffers<int8>({9}, {10}));
   EXPECT_TRUE(CompareEqualFloatBuffers<int8>({90}, {100}));
   EXPECT_TRUE(CompareEqualFloatBuffers<int8>({100}, {90}));
+  EXPECT_FALSE(CompareEqualFloatBuffers<int8>({-128}, {127}));
 }

 TEST_F(BufferComparatorTest, TestMultiple) {

From 6e14ed49f60af72c2c29826cf5e984f3e8f48b44 Mon Sep 17 00:00:00 2001
From: Yongfeng Gu
Date: Tue, 27 Aug 2019 17:16:18 -0400
Subject: [PATCH 4/4] Uniquify function names further.
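The helpers are declared inside extern "C", so their names reach the linker unmangled, and short generic names like canonicalize are easy to collide with. A sketch of the failure mode this rename guards against (hypothetical file names, not part of the patch):

    // a.cu -- one translation unit
    extern "C" __device__ float canonicalize(float x) { return x; }

    // b.cu -- another translation unit reusing the same generic name
    extern "C" __device__ float canonicalize(float x) { return -x; }

    // Linking both objects fails with a duplicate-symbol error, because
    // extern "C" suppresses C++ name mangling. Prefixing the helpers as
    // __xla_buffer_comparator_canonicalize and
    // __xla_buffer_comparator_extract_int8 namespaces them by convention.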
---
 .../compiler/xla/service/gpu/buffer_comparator.cc | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
index e9a2f86aabc..29f8677bfce 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
@@ -46,13 +46,13 @@ static constexpr double kTolerance = 0.1f;
 //
 // #include <cuda_fp16.h>
 // extern "C" {  // avoid name mangling
-// __device__ float canonicalize_fp16(float input) {
+// __device__ float __xla_buffer_comparator_canonicalize(float input) {
 //   // All fp16 infinities are treated as 65505 or -65505, in order to avoid
 //   // differences due to overflows.
 //   return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f));
 // }

-// __device__ float extract_int8(int pack) {
+// __device__ float __xla_buffer_comparator_extract_int8(int pack) {
 //   // Extract the lower 8 bits from pack and convert it to float
 //   const unsigned int bit_mask = 0xff;
 //   unsigned int bits = pack & bit_mask;
@@ -68,8 +68,8 @@ static constexpr double kTolerance = 0.1f;
 //   if (idx >= buffer_length) return;
 //   float elem_a = __half2float(buffer_a[idx]);
 //   float elem_b = __half2float(buffer_b[idx]);
-//   elem_a = canonicalize_fp16(elem_a);
-//   elem_b = canonicalize_fp16(elem_b);
+//   elem_a = __xla_buffer_comparator_canonicalize(elem_a);
+//   elem_b = __xla_buffer_comparator_canonicalize(elem_b);
 //   if (isnan(elem_a) && isnan(elem_b)) return;
 //   float rel_error = abs(elem_a - elem_b)
 //       / (max(abs(elem_a), abs(elem_b)) + 1);
@@ -120,8 +120,8 @@ static constexpr double kTolerance = 0.1f;
 //   int pack_a = buffer_a[idx];
 //   int pack_b = buffer_b[idx];
 //   for(int i = 0; i < 4; ++i) {
-//     float elem_a = extract_int8(pack_a);
-//     float elem_b = extract_int8(pack_b);
+//     float elem_a = __xla_buffer_comparator_extract_int8(pack_a);
+//     float elem_b = __xla_buffer_comparator_extract_int8(pack_b);
 //     float rel_error = abs(elem_a - elem_b)
 //         / (max(abs(elem_a), abs(elem_b)) + 1);
 //     if (rel_error > rel_error_threshold || isnan(rel_error))