Use a PTX blob to compare two buffers after the convolutions for correctness.
Extracts CompilePtxOrGetCached into stream_executor_util, fixes a minor bug in a host-side comparison function, and fixes a memory leak in the test. HLO compilation previously took about 30% of autotuning time on some benchmarks (the actual time saved is smaller, since computations run concurrently on the GPU). PiperOrigin-RevId: 246897145
parent 805dbd25af
commit ae53fb3375
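Note: at a high level, this change replaces on-the-fly HLO compilation with a pre-compiled PTX blob: compile (or fetch from cache) the PTX once per executor, load a type-safe kernel, launch it, and read back an atomic mismatch counter. A condensed sketch using only the helpers introduced in this diff (the function name is illustrative and error handling is elided; see DeviceCompare in buffer_comparator.cc below for the real version):

template <typename T>
StatusOr<bool> SketchDeviceCompare(se::Stream* stream, se::DeviceMemory<T> lhs,
                                   se::DeviceMemory<T> rhs, const Shape& shape,
                                   const HloModuleConfig& config,
                                   absl::string_view kernel_name) {
  se::StreamExecutor* executor = stream->parent();
  // Compiled once per (executor, options, ptx); later calls hit the cache.
  TF_ASSIGN_OR_RETURN(absl::Span<const uint8> cubin,
                      CompilePtxOrGetCached(executor, buffer_compare_ptx,
                                            PtxCompilationOptions(config)));
  // Argument types are fixed at creation time; mismatched launches won't compile.
  TF_ASSIGN_OR_RETURN(
      auto kernel,
      (CreateTypedKernel<se::DeviceMemory<T>, se::DeviceMemory<T>, float,
                         uint64, se::DeviceMemory<uint64>>(
          kernel_name, buffer_compare_ptx, cubin, executor)));
  se::ScopedDeviceMemory<uint64> mismatches =
      executor->AllocateOwnedScalar<uint64>();
  stream->ThenMemZero(mismatches.ptr(), sizeof(uint64));
  LaunchDimensions dim =
      CalculateLaunchDimensions(shape, executor->GetDeviceDescription());
  // The kernel atomically increments the counter for each mismatching element.
  stream->ThenLaunch(se::ThreadDim(dim.threads_per_block()),
                     se::BlockDim(dim.block_count()), *kernel, lhs, rhs,
                     static_cast<float>(kTolerance), lhs.ElementCount(),
                     mismatches.cref());
  uint64 n = 0;
  stream->ThenMemcpy(&n, *mismatches, sizeof(n));
  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
  return n == 0;  // Equal iff no element exceeded the relative-error bound.
}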
@@ -1,7 +1,6 @@
# Description:
#   GPU-specific components in XLA service implementation.

load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
load(
    "//tensorflow/core:platform/default/build_config_root.bzl",
@@ -537,6 +536,7 @@ tf_cc_test(
        ":redzone_allocator",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/service:device_memory_allocator",
        "//tensorflow/compiler/xla/service:hlo_module_config",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # fixdeps: keep
@@ -1160,30 +1160,30 @@ cc_library(
    hdrs = ["buffer_comparator.h"],
    deps = [
        ":gpu_executable",
        ":partition_assignment",
        ":stream_executor_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla/service:compiler",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/service:device_memory_allocator",
        "//tensorflow/compiler/xla/service:hlo_parser",
        "//tensorflow/compiler/xla/service:hlo_module_config",
        "//tensorflow/core:stream_executor_no_cuda",
        "//tensorflow/stream_executor:stream_executor_headers",
        "@com_google_absl//absl/strings",
    ],
)

xla_test(
tf_cc_test(
    name = "buffer_comparator_test",
    srcs = ["buffer_comparator_test.cc"],
    backends = [
        "cpu",
        "gpu",
    ],
    tags = tf_cuda_tests_tags(),
    deps = [
        ":buffer_comparator",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla/service:backend",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "@com_google_absl//absl/container:flat_hash_map",
        "//tensorflow/stream_executor:device_memory",
    ],
)

@@ -15,233 +15,387 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"

#include <algorithm>
#include <cmath>

#include "absl/strings/str_replace.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/kernel.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"

namespace xla {
namespace gpu {

static constexpr double kTolerance = 0.1f;

static StatusOr<string> GetCompHloText(const Shape& shape) {
  // Implements the textual format of the comparison routine, as it's more
  // readable.
  //
  // This text template takes four substitution parameters:
  // ${ORIG_TYPE}: buffer element type.
  // ${CMP_TYPE}: intermediate element type for calculating numeric differences.
  // ${SIZE}: number of elements.
  // ${CLAMP_TO}: clamp the values to [-${CLAMP_TO}, ${CLAMP_TO}].
  static constexpr char kCompHloText[] = R"(
HloModule Compare_${ORIG_TYPE}_${CMP_TYPE}_${SIZE}_${CLAMP_TO}

Max {
  %lhs = ${CMP_TYPE}[] parameter(0)
  %rhs = ${CMP_TYPE}[] parameter(1)
  ROOT %max = ${CMP_TYPE}[] maximum(%lhs, %rhs)
}

Canonicalize (aparam: ${ORIG_TYPE}[${SIZE}]) -> ${CMP_TYPE}[${SIZE}] {
  %min_constant = ${CMP_TYPE}[] constant(-${CLAMP_TO})
  %max_constant = ${CMP_TYPE}[] constant(${CLAMP_TO})
  %min_values = ${CMP_TYPE}[${SIZE}] broadcast(%min_constant), dimensions={}
  %max_values = ${CMP_TYPE}[${SIZE}] broadcast(%max_constant), dimensions={}

  %a = ${ORIG_TYPE}[${SIZE}] parameter(0)
  %converted = ${CMP_TYPE}[${SIZE}] convert(%a)
  ROOT %clamped = ${CMP_TYPE}[${SIZE}] clamp(%min_values, %converted, %max_values)
}

// RelError(x, y) = abs(x - y) / (max(abs(x), abs(y)) + 1)
// x and y must be finite.
RelError (aparam: ${CMP_TYPE}[${SIZE}], bparam: ${CMP_TYPE}[${SIZE}]) -> ${CMP_TYPE}[${SIZE}] {
  %lhs = ${CMP_TYPE}[${SIZE}] parameter(0)
  %rhs = ${CMP_TYPE}[${SIZE}] parameter(1)
  %one_constant = ${CMP_TYPE}[] constant(1.0)
  %ones = ${CMP_TYPE}[${SIZE}] broadcast(%one_constant), dimensions={}

  %sub = ${CMP_TYPE}[${SIZE}] subtract(%lhs, %rhs)
  %sub_abs = ${CMP_TYPE}[${SIZE}] abs(%sub)
  %lhs_abs = ${CMP_TYPE}[${SIZE}] abs(%lhs)
  %rhs_abs = ${CMP_TYPE}[${SIZE}] abs(%rhs)
  %max = ${CMP_TYPE}[${SIZE}] maximum(%lhs_abs, %rhs_abs)
  %denominator = ${CMP_TYPE}[${SIZE}] add(%max, %ones)
  ROOT %error = ${CMP_TYPE}[${SIZE}] divide(%sub_abs, %denominator)
}

// Here is the chain-style definition of this function:
//   Error(NaN, NaN) = 0
//   Error(Inf, Inf) = 0
//   Error(-Inf, -Inf) = 0
//   Error(NonFinite, x) = Inf
//   Error(x, NonFinite) = Inf
//   Error(x, y) = RelError(x, y)
// , where the earliest matching pattern takes precedence.
// Comparison kernel code: compare two buffers of fp16/fp32/fp64 of length
// buffer_length where the relative error does not exceed the passed
// rel_error_threshold. Write the number of mismatches into out parameter
// mismatch_count.
//
// To implement this, we start from the bottom, and keep using select to
// overwrite previously picked values. The last value produced by a matched
// pattern is the final value.
Error (aparam: ${CMP_TYPE}[${SIZE}], bparam: ${CMP_TYPE}[${SIZE}]) -> ${CMP_TYPE}[${SIZE}] {
  %lhs = ${CMP_TYPE}[${SIZE}] parameter(0)
  %rhs = ${CMP_TYPE}[${SIZE}] parameter(1)
  %zero_constant = ${CMP_TYPE}[] constant(0.0)
  %inf_constant = ${CMP_TYPE}[] constant(inf)
  %zeros = ${CMP_TYPE}[${SIZE}] broadcast(%zero_constant), dimensions={}
  %infs = ${CMP_TYPE}[${SIZE}] broadcast(%inf_constant), dimensions={}
// NaNs are considered equal, and for half precision we clamp all numbers to the
// largest and smallest representable values, to avoid miscomparisons due to
// overflows.
//
// The PTX below is compiled from the following CUDA code:
//
// #include <cuda_fp16.h>
// extern "C" {  // avoid name mangling
// __device__ float canonicalize(float input) {
//   // All fp16 infinities are treated as 65505 or -65505, in order to avoid
//   // differences due to overflows.
//   return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f));
// }
//
// __global__ void __xla_fp16_comparison(__half* buffer_a, __half* buffer_b,
//                                       float rel_error_threshold,
//                                       unsigned long long buffer_length,
//                                       int* mismatch_count) {
//   int idx = threadIdx.x + blockIdx.x * blockDim.x;
//   if (idx >= buffer_length) return;
//   float elem_a = __half2float(buffer_a[idx]);
//   float elem_b = __half2float(buffer_b[idx]);
//   elem_a = canonicalize(elem_a);
//   elem_b = canonicalize(elem_b);
//   if (isnan(elem_a) && isnan(elem_b)) return;
//   float rel_error = abs(elem_a - elem_b)
//       / (max(abs(elem_a), abs(elem_b)) + 1);
//   if (rel_error > rel_error_threshold || isnan(rel_error))
//     atomicAdd(mismatch_count, 1);
// }
//
// __global__ void __xla_fp32_comparison(float* buffer_a, float* buffer_b,
//                                       float rel_error_threshold,
//                                       unsigned long long buffer_length,
//                                       int* mismatch_count) {
//   int idx = threadIdx.x + blockIdx.x * blockDim.x;
//   if (idx >= buffer_length) return;
//   float elem_a = buffer_a[idx];
//   float elem_b = buffer_b[idx];
//   if (isnan(elem_a) && isnan(elem_b)) return;
//   if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b))
//     return;
//   float rel_error = abs(elem_a - elem_b)
//       / (max(abs(elem_a), abs(elem_b)) + 1);
//   if (rel_error > rel_error_threshold || isnan(rel_error))
//     atomicAdd(mismatch_count, 1);
// }
//
// __global__ void __xla_fp64_comparison(double* buffer_a, double* buffer_b,
//                                       float rel_error_threshold,
//                                       unsigned long long buffer_length,
//                                       int* mismatch_count) {
//   int idx = threadIdx.x + blockIdx.x * blockDim.x;
//   if (idx >= buffer_length) return;
//   double elem_a = buffer_a[idx];
//   double elem_b = buffer_b[idx];
//   if (isnan(elem_a) && isnan(elem_b)) return;
//   if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b))
//     return;
//   double rel_error = abs(elem_a - elem_b)
//       / (max(abs(elem_a), abs(elem_b)) + 1);
//   if (rel_error > rel_error_threshold || isnan(rel_error))
//     atomicAdd(mismatch_count, 1);
// }
// }  // end extern declaration.
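Note: to make the comparison rule above concrete, here is a small standalone host-side mirror of these kernels (an illustration only; the function name is hypothetical and not part of this change):

#include <algorithm>
#include <cassert>
#include <cmath>

// Host mirror of the CUDA comparison rule above (illustrative only).
bool ElementsMatch(float a, float b, float rel_error_threshold) {
  if (std::isnan(a) && std::isnan(b)) return true;  // NaN == NaN here.
  if (std::isinf(a) && std::isinf(b) &&
      std::signbit(a) == std::signbit(b))
    return true;  // Same-signed infinities match.
  float rel_error = std::abs(a - b) / (std::max(std::abs(a), std::abs(b)) + 1);
  // A NaN rel_error means exactly one side was NaN/Inf: count as a mismatch.
  return rel_error <= rel_error_threshold && !std::isnan(rel_error);
}

int main() {
  assert(ElementsMatch(1000.0f, 1001.0f, 0.1f));  // rel error ~= 0.001
  assert(ElementsMatch(0.0f, 0.05f, 0.1f));       // +1 keeps tiny values sane
  assert(!ElementsMatch(1.0f, 2.0f, 0.1f));       // rel error ~= 0.333
  assert(!ElementsMatch(NAN, 1.0f, 0.1f));        // one-sided NaN mismatches
}

The +1 in the denominator makes the check behave like an absolute-error bound near zero and a relative-error bound for large magnitudes.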
static const char* buffer_compare_ptx = R"(
.version 4.2
.target sm_30
.address_size 64

  %lhs_is_finite = pred[${SIZE}] is-finite(%lhs)
  %lhs_is_not_finite = pred[${SIZE}] not(%lhs_is_finite)
  %lhs_is_not_nan = pred[${SIZE}] compare(%lhs, %lhs), direction=EQ
  %lhs_is_nan = pred[${SIZE}] not(%lhs_is_not_nan)
  %lhs_is_inf = pred[${SIZE}] and(%lhs_is_not_finite, %lhs_is_not_nan)
  %lhs_is_non_neg = pred[${SIZE}] compare(%lhs, %zeros), direction=GE
.visible .entry __xla_fp16_comparison(
  .param .u64 __xla_fp16_comparison_param_0,
  .param .u64 __xla_fp16_comparison_param_1,
  .param .f32 __xla_fp16_comparison_param_2,
  .param .u64 __xla_fp16_comparison_param_3,
  .param .u64 __xla_fp16_comparison_param_4
)
{
  .reg .pred %p<10>;
  .reg .b16 %rs<3>;
  .reg .f32 %f<20>;
  .reg .b32 %r<6>;
  .reg .b64 %rd<12>;
  ld.param.u64 %rd8, [__xla_fp16_comparison_param_3];
  mov.u32 %r1, %tid.x;
  mov.u32 %r2, %ctaid.x;
  mov.u32 %r3, %ntid.x;
  mad.lo.s32 %r4, %r3, %r2, %r1;
  cvt.s64.s32 %rd4, %r4;
  setp.ge.u64 %p1, %rd4, %rd8;
  @%p1 bra LBB7_4;
  ld.param.u64 %rd5, [__xla_fp16_comparison_param_0];
  ld.param.u64 %rd7, [__xla_fp16_comparison_param_1];
  cvta.to.global.u64 %rd2, %rd7;
  cvta.to.global.u64 %rd3, %rd5;
  shl.b64 %rd9, %rd4, 1;
  add.s64 %rd10, %rd3, %rd9;
  ld.global.u16 %rs1, [%rd10];
  // begin inline asm
  { cvt.f32.f16 %f6, %rs1;}

  %rhs_is_finite = pred[${SIZE}] is-finite(%rhs)
  %rhs_is_not_finite = pred[${SIZE}] not(%rhs_is_finite)
  %rhs_is_not_nan = pred[${SIZE}] compare(%rhs, %rhs), direction=EQ
  %rhs_is_nan = pred[${SIZE}] not(%rhs_is_not_nan)
  %rhs_is_inf = pred[${SIZE}] and(%rhs_is_not_finite, %rhs_is_not_nan)
  %rhs_is_non_neg = pred[${SIZE}] compare(%rhs, %zeros), direction=GE
  // end inline asm
  add.s64 %rd11, %rd2, %rd9;
  ld.global.u16 %rs2, [%rd11];
  // begin inline asm
  { cvt.f32.f16 %f7, %rs2;}

  %both_same_sign = pred[${SIZE}] and(%lhs_is_non_neg, %rhs_is_non_neg)
  %both_inf = pred[${SIZE}] and(%lhs_is_inf, %rhs_is_inf)
  %both_same_sign_inf = pred[${SIZE}] and(%both_same_sign, %both_inf)
  %both_nan = pred[${SIZE}] and(%lhs_is_nan, %rhs_is_nan)
  // end inline asm
  abs.f32 %f8, %f6;
  setp.gtu.f32 %p2, %f8, 0f7F800000;
  min.f32 %f9, %f6, 0f477FE100;
  max.f32 %f10, %f9, 0fC77FE100;
  selp.f32 %f1, %f6, %f10, %p2;
  abs.f32 %f11, %f7;
  setp.gtu.f32 %p3, %f11, 0f7F800000;
  min.f32 %f12, %f7, 0f477FE100;
  max.f32 %f13, %f12, 0fC77FE100;
  selp.f32 %f2, %f7, %f13, %p3;
  abs.f32 %f3, %f1;
  setp.gtu.f32 %p4, %f3, 0f7F800000;
  abs.f32 %f4, %f2;
  setp.gtu.f32 %p5, %f4, 0f7F800000;
  and.pred %p6, %p4, %p5;
  @%p6 bra LBB7_4;
  ld.param.f32 %f5, [__xla_fp16_comparison_param_2];
  sub.f32 %f14, %f1, %f2;
  abs.f32 %f15, %f14;
  max.f32 %f16, %f3, %f4;
  add.f32 %f17, %f16, 0f3F800000;
  div.rn.f32 %f18, %f15, %f17;
  setp.leu.f32 %p7, %f18, %f5;
  abs.f32 %f19, %f18;
  setp.le.f32 %p8, %f19, 0f7F800000;
  and.pred %p9, %p7, %p8;
  @%p9 bra LBB7_4;
  ld.param.u64 %rd6, [__xla_fp16_comparison_param_4];
  cvta.to.global.u64 %rd1, %rd6;
  atom.global.add.u32 %r5, [%rd1], 1;
LBB7_4:
  ret;

  // Reverse-order selections

  // Error(x, y) = RelError(x, y)
  %rel_error = ${CMP_TYPE}[${SIZE}] call(%lhs, %rhs), to_apply=RelError
  // Error(x, NonFinite) = Inf
  %after_x_non_finite = ${CMP_TYPE}[${SIZE}] select(%rhs_is_not_finite, %infs, %rel_error)
  // Error(NonFinite, x) = Inf
  %after_non_finite_x = ${CMP_TYPE}[${SIZE}] select(%lhs_is_not_finite, %infs, %after_x_non_finite)
  // Error(-Inf, -Inf) = 0
  // Error(Inf, Inf) = 0
  %after_both_same_sign_inf = ${CMP_TYPE}[${SIZE}] select(%both_same_sign_inf, %zeros, %after_non_finite_x)
  // Error(NaN, NaN) = 0
  ROOT %after_both_nan = ${CMP_TYPE}[${SIZE}] select(%both_nan, %zeros, %after_both_same_sign_inf)
}
//   .globl  __xla_fp32_comparison
.visible .entry __xla_fp32_comparison(
  .param .u64 __xla_fp32_comparison_param_0,
  .param .u64 __xla_fp32_comparison_param_1,
  .param .f32 __xla_fp32_comparison_param_2,
  .param .u64 __xla_fp32_comparison_param_3,
  .param .u64 __xla_fp32_comparison_param_4
)
{
  .reg .pred %p<12>;
  .reg .f32 %f<12>;
  .reg .b32 %r<9>;
  .reg .b64 %rd<12>;

ENTRY MaxDifference {
  %zero_constant = ${CMP_TYPE}[] constant(0.0)
  ld.param.u64 %rd8, [__xla_fp32_comparison_param_3];
  mov.u32 %r1, %tid.x;
  mov.u32 %r2, %ctaid.x;
  mov.u32 %r3, %ntid.x;
  mad.lo.s32 %r4, %r3, %r2, %r1;
  cvt.s64.s32 %rd4, %r4;
  setp.ge.u64 %p1, %rd4, %rd8;
  @%p1 bra LBB8_6;
  ld.param.u64 %rd5, [__xla_fp32_comparison_param_0];
  ld.param.u64 %rd7, [__xla_fp32_comparison_param_1];
  cvta.to.global.u64 %rd2, %rd7;
  cvta.to.global.u64 %rd3, %rd5;
  shl.b64 %rd9, %rd4, 2;
  add.s64 %rd10, %rd3, %rd9;
  ld.global.f32 %f1, [%rd10];
  add.s64 %rd11, %rd2, %rd9;
  ld.global.f32 %f2, [%rd11];
  abs.f32 %f3, %f1;
  setp.gtu.f32 %p2, %f3, 0f7F800000;
  abs.f32 %f4, %f2;
  setp.gtu.f32 %p3, %f4, 0f7F800000;
  and.pred %p4, %p2, %p3;
  @%p4 bra LBB8_6;
  setp.neu.f32 %p5, %f3, 0f7F800000;
  setp.neu.f32 %p6, %f4, 0f7F800000;
  or.pred %p7, %p5, %p6;
  @%p7 bra LBB8_4;
  mov.b32 %r5, %f1;
  mov.b32 %r6, %f2;
  xor.b32 %r7, %r6, %r5;
  setp.gt.s32 %p8, %r7, -1;
  @%p8 bra LBB8_6;
LBB8_4:
  ld.param.f32 %f5, [__xla_fp32_comparison_param_2];
  sub.f32 %f6, %f1, %f2;
  abs.f32 %f7, %f6;
  max.f32 %f8, %f3, %f4;
  add.f32 %f9, %f8, 0f3F800000;
  div.rn.f32 %f10, %f7, %f9;
  setp.leu.f32 %p9, %f10, %f5;
  abs.f32 %f11, %f10;
  setp.le.f32 %p10, %f11, 0f7F800000;
  and.pred %p11, %p9, %p10;
  @%p11 bra LBB8_6;
  ld.param.u64 %rd6, [__xla_fp32_comparison_param_4];
  cvta.to.global.u64 %rd1, %rd6;
  atom.global.add.u32 %r8, [%rd1], 1;
LBB8_6:
  ret;

  %lhs = ${ORIG_TYPE}[${SIZE}] parameter(0)
  %rhs = ${ORIG_TYPE}[${SIZE}] parameter(1)
  %lhs_canonical = ${CMP_TYPE}[${SIZE}] call(%lhs), to_apply=Canonicalize
  %rhs_canonical = ${CMP_TYPE}[${SIZE}] call(%rhs), to_apply=Canonicalize
  %error = ${CMP_TYPE}[${SIZE}] call(%lhs_canonical, %rhs_canonical), to_apply=Error
  %max_diff = ${CMP_TYPE}[] reduce(%error, %zero_constant), dimensions={0}, to_apply=Max
  ROOT %converted_max_diff = f64[] convert(%max_diff)
})";
}
//   .globl  __xla_fp64_comparison
.visible .entry __xla_fp64_comparison(
  .param .u64 __xla_fp64_comparison_param_0,
  .param .u64 __xla_fp64_comparison_param_1,
  .param .f32 __xla_fp64_comparison_param_2,
  .param .u64 __xla_fp64_comparison_param_3,
  .param .u64 __xla_fp64_comparison_param_4
)
{
  .reg .pred %p<16>;
  .reg .f32 %f<2>;
  .reg .b32 %r<13>;
  .reg .f64 %fd<12>;
  .reg .b64 %rd<12>;

  absl::string_view orig_type;
  absl::string_view cmp_type;
  string clamp_to;

  switch (shape.element_type()) {
    case xla::F16:
      orig_type = "f16";
      cmp_type = "f32";
      // Clamp fp16s to 65505, since they actually overflow a lot in practice.
      // This way, +infs and values like 65504 are considered to be within
      // tolerance.
      clamp_to = "65505";
      break;
    case xla::F32:
      orig_type = "f32";
      cmp_type = "f32";
      clamp_to = "inf";
      break;
    case xla::F64:
      orig_type = "f64";
      cmp_type = "f64";
      clamp_to = "inf";
      break;
    default:
      return Unimplemented("Unimplemented element type");
  ld.param.u64 %rd8, [__xla_fp64_comparison_param_3];
  mov.u32 %r2, %tid.x;
  mov.u32 %r3, %ctaid.x;
  mov.u32 %r4, %ntid.x;
  mad.lo.s32 %r5, %r4, %r3, %r2;
  cvt.s64.s32 %rd4, %r5;
  setp.ge.u64 %p1, %rd4, %rd8;
  @%p1 bra LBB9_6;
  ld.param.u64 %rd5, [__xla_fp64_comparison_param_0];
  ld.param.u64 %rd7, [__xla_fp64_comparison_param_1];
  cvta.to.global.u64 %rd2, %rd7;
  cvta.to.global.u64 %rd3, %rd5;
  shl.b64 %rd9, %rd4, 3;
  add.s64 %rd10, %rd3, %rd9;
  ld.global.f64 %fd1, [%rd10];
  add.s64 %rd11, %rd2, %rd9;
  ld.global.f64 %fd2, [%rd11];
  abs.f64 %fd3, %fd1;
  setp.gtu.f64 %p2, %fd3, 0d7FF0000000000000;
  abs.f64 %fd4, %fd2;
  setp.gtu.f64 %p3, %fd4, 0d7FF0000000000000;
  and.pred %p4, %p2, %p3;
  @%p4 bra LBB9_6;
  {
  .reg .b32 %temp;
  mov.b64 {%r6, %temp}, %fd1;
  }

  string size_str = absl::StrCat(ShapeUtil::ElementsIn(shape));
  return absl::StrReplaceAll(kCompHloText, {
      {"${ORIG_TYPE}", orig_type},
      {"${CMP_TYPE}", cmp_type},
      {"${SIZE}", size_str},
      {"${CLAMP_TO}", clamp_to},
  });
  {
  .reg .b32 %temp;
  mov.b64 {%temp, %r1}, %fd1;
  }
  and.b32 %r7, %r1, 2147483647;
  setp.ne.s32 %p5, %r7, 2146435072;
  setp.ne.s32 %p6, %r6, 0;
  or.pred %p7, %p6, %p5;
  @%p7 bra LBB9_4;
  {
  .reg .b32 %temp;
  mov.b64 {%r8, %temp}, %fd2;
  }
  {
  .reg .b32 %temp;
  mov.b64 {%temp, %r9}, %fd2;
  }
  and.b32 %r10, %r9, 2147483647;
  setp.eq.s32 %p8, %r10, 2146435072;
  setp.eq.s32 %p9, %r8, 0;
  and.pred %p10, %p8, %p9;
  xor.b32 %r11, %r9, %r1;
  setp.gt.s32 %p11, %r11, -1;
  and.pred %p12, %p11, %p10;
  @%p12 bra LBB9_6;
LBB9_4:
  ld.param.f32 %f1, [__xla_fp64_comparison_param_2];
  sub.f64 %fd5, %fd1, %fd2;
  abs.f64 %fd6, %fd5;
  max.f64 %fd7, %fd3, %fd4;
  add.f64 %fd8, %fd7, 0d3FF0000000000000;
  div.rn.f64 %fd9, %fd6, %fd8;
  cvt.f64.f32 %fd10, %f1;
  setp.leu.f64 %p13, %fd9, %fd10;
  abs.f64 %fd11, %fd9;
  setp.le.f64 %p14, %fd11, 0d7FF0000000000000;
  and.pred %p15, %p13, %p14;
  @%p15 bra LBB9_6;
  ld.param.u64 %rd6, [__xla_fp64_comparison_param_4];
  cvta.to.global.u64 %rd1, %rd6;
  atom.global.add.u32 %r12, [%rd1], 1;
LBB9_6:
  ret;
}
)";

StatusOr<BufferComparator> BufferComparator::Create(
    const Shape& shape, se::StreamExecutor* stream_exec, Compiler* compiler) {
  // One may consider using hlo_runner to do all the compilation and execution.
  // However, as of now hlo_runner doesn't support injection of a Compiler* or
  // a Stream*. We may revisit this in the future if it proves to be a
  // maintenance burden.
  TF_ASSIGN_OR_RETURN(
      auto exec, ([&]() -> StatusOr<std::unique_ptr<Executable>> {
        HloModuleConfig config;
        DebugOptions debug_options;
        debug_options.set_xla_backend_optimization_level(2);
        config.set_debug_options(debug_options);
        TF_ASSIGN_OR_RETURN(string hlo_text, GetCompHloText(shape));
        TF_ASSIGN_OR_RETURN(auto module, ParseHloString(hlo_text, config));
        TF_ASSIGN_OR_RETURN(
            module,
            compiler->RunHloPasses(std::move(module), stream_exec, nullptr));
        return compiler->RunBackend(std::move(module), stream_exec, nullptr);
      }()));
template <typename ElementT>
using ComparisonKernelT =
    se::TypedKernel<se::DeviceMemory<ElementT>, se::DeviceMemory<ElementT>,
                    float, uint64, se::DeviceMemory<uint64>>;

  return BufferComparator(shape, std::move(exec));
}
// Compares two buffers on the GPU.
//
// Returns `true` if two buffers are equal, `false` otherwise.
template <typename ElementT>
static StatusOr<bool> DeviceCompare(se::Stream* stream,
                                    se::DeviceMemoryBase lhs,
                                    se::DeviceMemoryBase rhs,
                                    const Shape& buffer_shape,
                                    const HloModuleConfig& config,
                                    absl::string_view kernel_name) {
  se::StreamExecutor* executor = stream->parent();

StatusOr<bool> BufferComparator::CompareEqualImpl(
    se::Stream* stream, DeviceMemoryAllocator* allocator,
    se::DeviceMemoryBase lhs, se::DeviceMemoryBase rhs) {
  se::ScopedDeviceMemory<uint64> out_param =
      executor->AllocateOwnedScalar<uint64>();

  stream->ThenMemZero(out_param.ptr(), sizeof(uint64));
  if (lhs.size() != rhs.size()) {
    return InternalError("Mismatched buffer size: %d bytes vs %d bytes",
    return InternalError("Mismatched buffer size: %d bytes vs. %d bytes",
                         lhs.size(), rhs.size());
  }

  auto stream_exec = stream->parent();
  auto to_shaped_buffer =
      [stream_exec,
       this](se::DeviceMemoryBase buffer) -> StatusOr<ShapedBuffer> {
    auto device_ordinal = stream_exec->device_ordinal();
    ShapedBuffer shaped(shape_, shape_, stream_exec->platform(),
                        device_ordinal);
    shaped.set_buffer(buffer, {});
    return std::move(shaped);
  };
  se::DeviceMemory<ElementT> lhs_typed(lhs);
  se::DeviceMemory<ElementT> rhs_typed(rhs);
  uint64 buffer_size = lhs_typed.ElementCount();

  TF_ASSIGN_OR_RETURN(auto shaped_lhs, to_shaped_buffer(lhs));
  TF_ASSIGN_OR_RETURN(auto shaped_rhs, to_shaped_buffer(rhs));
  PtxCompilationOptions opts(config);
  TF_ASSIGN_OR_RETURN(
      absl::Span<const uint8> compiled_ptx,
      CompilePtxOrGetCached(executor, buffer_compare_ptx, opts));

  ExecutableRunOptions run_options;
  run_options.set_device_ordinal(stream_exec->device_ordinal());
  run_options.set_stream(stream);
  run_options.set_allocator(allocator);
  ServiceExecutableRunOptions service_run_options(run_options);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<ComparisonKernelT<ElementT>> comparison_kernel,
      (CreateTypedKernel<se::DeviceMemory<ElementT>, se::DeviceMemory<ElementT>,
                         float, uint64, se::DeviceMemory<uint64>>(
          kernel_name, buffer_compare_ptx, compiled_ptx, executor)));

  const ShapedBuffer* arg_buffers[] = {&shaped_lhs, &shaped_rhs};
  TF_ASSIGN_OR_RETURN(auto result_buffer,
                      comparator_exec_->ExecuteOnStream(&service_run_options,
                                                        arg_buffers, nullptr));
  LaunchDimensions dim =
      CalculateLaunchDimensions(buffer_shape, executor->GetDeviceDescription());

  double result;
  CHECK(result_buffer.root_buffer().size() == sizeof(result));
  stream->ThenMemcpy(&result, result_buffer.root_buffer(), sizeof(result));
  stream->ThenLaunch(se::ThreadDim(dim.threads_per_block()),
                     se::BlockDim(dim.block_count()), *comparison_kernel,
                     lhs_typed, rhs_typed, static_cast<float>(kTolerance),
                     buffer_size, out_param.cref());

  uint64 result = -1;
  CHECK_EQ(out_param->size(), sizeof(result));
  stream->ThenMemcpy(&result, *out_param, sizeof(result));
  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
  return result < kTolerance;
  return result == 0;
}

// Host-side comparison code that does the same thing, but reports some of the
// differences as well. It only prints logs for debugging.
//
// Returns true if no differences were seen, false otherwise.
template <typename ElementType, typename ComparisonType>
Status HostCompare(se::Stream* stream, se::DeviceMemoryBase lhs,
                   se::DeviceMemoryBase rhs) {
StatusOr<bool> HostCompare(se::Stream* stream, se::DeviceMemoryBase lhs,
                           se::DeviceMemoryBase rhs) {
  int64 n = lhs.size() / sizeof(ElementType);
  std::vector<ElementType> host_lhs(n), host_rhs(n);
  stream->ThenMemcpy(host_lhs.data(), lhs, lhs.size());
@@ -250,14 +404,11 @@ Status HostCompare(se::Stream* stream, se::DeviceMemoryBase lhs,

  const auto canonicalize = [](ComparisonType a) -> ComparisonType {
    if (std::is_same<ElementType, Eigen::half>::value && a) {
      constexpr float kMaxFp16Value = 65504.;
      constexpr ComparisonType kMaxFp16Value = 65505.;
      if (std::isnan(a)) {
        return a;
      }
      if (a < 0) {
        return -(kMaxFp16Value + 1);
      }
      return kMaxFp16Value + 1;
      return std::max(-kMaxFp16Value, std::min(a, kMaxFp16Value));
    }
    return a;
  };
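Note: this hunk is the host-side bug fix mentioned in the commit message. The old branch remapped every nonzero fp16 value to ±(kMaxFp16Value + 1), while the fixed version clamps only values outside [-65505, 65505]. A standalone illustration (hypothetical names, not part of the diff):

#include <algorithm>
#include <cassert>
#include <cmath>

float OldCanonicalize(float a) {  // Buggy: remaps every nonzero value.
  constexpr float kMax = 65504.f;
  if (std::isnan(a)) return a;
  return a < 0 ? -(kMax + 1) : kMax + 1;
}

float NewCanonicalize(float a) {  // Fixed: clamps to [-65505, 65505].
  constexpr float kMax = 65505.f;
  if (std::isnan(a)) return a;
  return std::max(-kMax, std::min(a, kMax));
}

int main() {
  assert(OldCanonicalize(2.0f) == 65505.f);      // finite value clobbered
  assert(NewCanonicalize(2.0f) == 2.0f);         // finite value preserved
  assert(NewCanonicalize(INFINITY) == 65505.f);  // infs still clamp
}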
@@ -281,36 +432,49 @@ Status HostCompare(se::Stream* stream, se::DeviceMemoryBase lhs,
                 << original_rhs;
    }
  }
  return Status::OK();
  return differences_seen == 0;
}

StatusOr<bool> BufferComparator::CompareEqual(se::Stream* stream,
                                              DeviceMemoryAllocator* allocator,
                                              se::DeviceMemoryBase lhs,
                                              se::DeviceMemoryBase rhs) {
  TF_ASSIGN_OR_RETURN(auto result,
                      CompareEqualImpl(stream, allocator, lhs, rhs));
template <typename ElementT, typename ComparisonT>
static StatusOr<bool> CompareEqualParameterized(se::Stream* stream,
                                                se::DeviceMemoryBase lhs,
                                                se::DeviceMemoryBase rhs,
                                                const Shape& shape,
                                                const HloModuleConfig& config,
                                                absl::string_view kernel_name) {
  XLA_SCOPED_LOGGING_TIMER("BufferComparator::CompareEqual");
  TF_ASSIGN_OR_RETURN(
      bool result,
      DeviceCompare<ElementT>(stream, lhs, rhs, shape, config, kernel_name));

  if (result) {
    return true;
  }

  switch (shape_.element_type()) {
    case xla::F16:
      TF_RETURN_IF_ERROR(HostCompare<Eigen::half, float>(stream, lhs, rhs));
      break;
    case xla::F32:
      TF_RETURN_IF_ERROR(HostCompare<float, float>(stream, lhs, rhs));
      break;
    case xla::F64:
      TF_RETURN_IF_ERROR(HostCompare<double, double>(stream, lhs, rhs));
      break;
    default:
      return Unimplemented("Unimplemented element type");
  }
  TF_ASSIGN_OR_RETURN(bool host_return,
                      (HostCompare<ElementT, ComparisonT>(stream, lhs, rhs)));
  CHECK(host_return == result) << "Different comparison result on GPU vs host";

  return false;
}

StatusOr<bool> BufferComparator::CompareEqual(se::Stream* stream,
                                              se::DeviceMemoryBase lhs,
                                              se::DeviceMemoryBase rhs) {
  switch (shape_.element_type()) {
    case xla::F16:
      return CompareEqualParameterized<Eigen::half, float>(
          stream, lhs, rhs, shape_, config_, "__xla_fp16_comparison");
    case xla::F32:
      return CompareEqualParameterized<float, float>(
          stream, lhs, rhs, shape_, config_, "__xla_fp32_comparison");
    case xla::F64:
      return CompareEqualParameterized<double, double>(
          stream, lhs, rhs, shape_, config_, "__xla_fp64_comparison");
    default:
      return Unimplemented("Unimplemented element type");
  }
}

}  // namespace gpu
}  // namespace xla

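Note: with this change, callers no longer go through the Create factory or pass an allocator; construction is trivial. A sketch of a call site, mirroring the test and the autotuner change below (stream, the device buffers, num_elements, and hlo_module_config are assumed to exist):

BufferComparator comparator(ShapeUtil::MakeShape(xla::F32, {num_elements}),
                            hlo_module_config);
TF_ASSIGN_OR_RETURN(bool equal,
                    comparator.CompareEqual(&stream, lhs_buffer, rhs_buffer));
if (!equal) {
  LOG(ERROR) << "Buffers mismatched beyond the relative-error tolerance.";
}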
@@ -16,9 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_

#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

@@ -31,9 +30,8 @@ class BufferComparator {
  BufferComparator(const BufferComparator&) = delete;
  BufferComparator(BufferComparator&&) = default;

  static StatusOr<BufferComparator> Create(const Shape& buffer_shape,
                                           se::StreamExecutor* stream_exec,
                                           Compiler* compiler);
  BufferComparator(const Shape& shape, const HloModuleConfig& config)
      : shape_(shape), config_(config) {}

  // Returns true if the two buffers compare equal. The definition of "equal"
  // is:
@@ -45,21 +43,12 @@ class BufferComparator {
  //
  // See the implementation for the tolerance value.
  StatusOr<bool> CompareEqual(se::Stream* stream,
                              DeviceMemoryAllocator* allocator,
                              se::DeviceMemoryBase lhs,
                              se::DeviceMemoryBase rhs);

 private:
  BufferComparator(const Shape& shape, std::unique_ptr<Executable> exec)
      : shape_(shape), comparator_exec_(std::move(exec)) {}

  StatusOr<bool> CompareEqualImpl(se::Stream* stream,
                                  DeviceMemoryAllocator* allocator,
                                  se::DeviceMemoryBase lhs,
                                  se::DeviceMemoryBase rhs);

  Shape shape_;
  std::unique_ptr<Executable> comparator_exec_;
  HloModuleConfig config_;
};

}  // namespace gpu

@@ -16,10 +16,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"

#include <limits>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/backend.h"

#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/stream_executor/device_memory.h"

namespace xla {
namespace gpu {
@@ -28,11 +29,9 @@ namespace {
class BufferComparatorTest : public testing::Test {
 protected:
  BufferComparatorTest()
      : backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()),
        stream_exec_(backend_->default_stream_executor()),
        allocator_(stream_exec_->platform(), {stream_exec_}),
        compiler_(Compiler::GetForPlatform(stream_exec_->platform())
                      .ConsumeValueOrDie()) {}
      : platform_(
            se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie()),
        stream_exec_(platform_->ExecutorForDevice(0).ValueOrDie()) {}

  // Take floats only for convenience. Still uses ElementType internally.
  template <typename ElementType>
@@ -43,49 +42,26 @@ class BufferComparatorTest : public testing::Test {
    se::Stream stream(stream_exec_);
    stream.Init();

    auto owning_lhs_buffer = allocator_
                                 .Allocate(stream_exec_->device_ordinal(),
                                           lhs.size() * sizeof(ElementType))
                                 .ConsumeValueOrDie();

    auto owning_rhs_buffer = allocator_
                                 .Allocate(stream_exec_->device_ordinal(),
                                           rhs.size() * sizeof(ElementType))
                                 .ConsumeValueOrDie();

    auto lhs_buffer =
        se::DeviceMemory<ElementType>(owning_lhs_buffer.AsDeviceMemoryBase());
    auto rhs_buffer =
        se::DeviceMemory<ElementType>(owning_rhs_buffer.AsDeviceMemoryBase());

    stream.ThenMemcpy(&lhs_buffer, lhs.data(), lhs_buffer.size());
    stream.ThenMemcpy(&rhs_buffer, rhs.data(), rhs_buffer.size());
    se::ScopedDeviceMemory<ElementType> lhs_buffer =
        stream_exec_->AllocateOwnedArray<ElementType>(lhs.size());
    se::ScopedDeviceMemory<ElementType> rhs_buffer =
        stream_exec_->AllocateOwnedArray<ElementType>(lhs.size());

    stream.ThenMemcpy(lhs_buffer.ptr(), lhs.data(), lhs_buffer->size());
    stream.ThenMemcpy(rhs_buffer.ptr(), rhs.data(), rhs_buffer->size());
    TF_CHECK_OK(stream.BlockHostUntilDone());

    static auto* cmp_cache =
        new absl::flat_hash_map<std::pair<PrimitiveType, int64>,
                                std::unique_ptr<BufferComparator>>();
    auto key =
        std::make_pair(primitive_util::NativeToPrimitiveType<ElementType>(),
                       static_cast<int64>(lhs_buffer.ElementCount()));
    std::unique_ptr<BufferComparator>& comparator = (*cmp_cache)[key];
    if (!comparator) {
      comparator.reset(new BufferComparator(
          BufferComparator::Create(
              ShapeUtil::MakeShape(key.first, {key.second}), stream.parent(),
              compiler_)
              .ConsumeValueOrDie()));
    }
    return comparator
        ->CompareEqual(&stream, &allocator_, lhs_buffer, rhs_buffer)
    BufferComparator comparator(
        ShapeUtil::MakeShape(
            primitive_util::NativeToPrimitiveType<ElementType>(),
            {static_cast<int64>(lhs_buffer->ElementCount())}),
        HloModuleConfig());
    return comparator.CompareEqual(&stream, *lhs_buffer, *rhs_buffer)
        .ConsumeValueOrDie();
  }

  std::unique_ptr<Backend> backend_;
  se::Platform* platform_;
  se::StreamExecutor* stream_exec_;
  StreamExecutorMemoryAllocator allocator_;
  Compiler* compiler_;
};

TEST_F(BufferComparatorTest, TestNaNs) {

@@ -392,7 +392,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
    if (comparator.has_value()) {
      XLA_SCOPED_LOGGING_TIMER_LEVEL("BufferComparator::CompareEqual", 2);
      StatusOr<bool> compare_result = comparator->CompareEqual(
          &stream, allocator, reference_result_buffer, result_buffer);
          &stream, reference_result_buffer, result_buffer);
      if (!compare_result.ok()) {
        LOG(ERROR) << "Unable to compare " << AlgorithmToString(first_algorithm)
                   << " against " << AlgorithmToString(alg) << " for "
@@ -420,21 +420,13 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
      }
    } else {
      XLA_SCOPED_LOGGING_TIMER_LEVEL("BufferComparator::Create", 2);
      auto comp =
          BufferComparator::Create(result_shape, stream.parent(), compiler_);
      if (comp.ok()) {
        comparator.emplace(comp.ConsumeValueOrDie());
        reference_result_buffer = result_buffer;
        TF_ASSIGN_OR_RETURN(result_buffer,
                            input_output_allocator.AllocateBytes(
                                &stream, reference_result_buffer.size()));
        initialize_buffer(result_buffer);
        first_algorithm = alg;
      } else {
        LOG(ERROR) << "Fail to initialize buffer comparator: " << comp.status()
                   << ", instruction: " << instr->ToString();
        CHECK(!crash_on_checking_failure);
      }
      comparator.emplace(result_shape, hlo_module_config);
      reference_result_buffer = result_buffer;
      TF_ASSIGN_OR_RETURN(result_buffer,
                          input_output_allocator.AllocateBytes(
                              &stream, reference_result_buffer.size()));
      initialize_buffer(result_buffer);
      first_algorithm = alg;
    }
  }

@@ -153,37 +153,6 @@ LBB6_3:
using ComparisonKernelT = se::TypedKernel<se::DeviceMemory<uint8>, uint8,
                                          uint64, se::DeviceMemory<uint64>>;

// Compile PTX in redzone_checker_ptx, or get a cached compiled version (for a
// given stream executor and a given CUDA directory specified by an XLA flag).
static StatusOr<absl::Span<const uint8>> CompileRedzoneCheckPtxOrGetCached(
    se::StreamExecutor* executor, const HloModuleConfig& hlo_module_config) {
  // Cache for storing the compiled PTX for redzone checking.
  // The cache key is a stream executor, as it determines the supported
  // CUDA compute capability, and PtxCompilationOptions.
  using PtxCacheKey =
      std::pair<se::StreamExecutor*, PtxCompilationOptions::PtxOptionsTuple>;
  static tensorflow::mutex ptx_cache_mutex(tensorflow::LINKER_INITIALIZED);
  static auto& redzone_check_ptx_cache GUARDED_BY(ptx_cache_mutex) =
      *new absl::flat_hash_map<PtxCacheKey, std::vector<uint8>>();

  tensorflow::mutex_lock lock(ptx_cache_mutex);
  PtxCompilationOptions compilation_options(hlo_module_config);
  PtxCacheKey cache_key{executor, compilation_options.ToTuple()};
  auto it = redzone_check_ptx_cache.find(cache_key);
  if (it != redzone_check_ptx_cache.end()) {
    return {it->second};
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<uint8> compiled,
      CompilePtx(executor, redzone_checker_ptx, compilation_options));

  auto insert_result =
      redzone_check_ptx_cache.emplace(cache_key, std::move(compiled));
  CHECK(insert_result.second);
  return {insert_result.first->second};
}

// Check that redzones weren't overwritten on a host.
//
// Slower, but gives a more useful error message.
@@ -297,7 +266,8 @@ Status RedzoneAllocator::CheckRedzones(se::Stream* stream) const {

  TF_ASSIGN_OR_RETURN(
      absl::Span<const uint8> compiled_ptx,
      CompileRedzoneCheckPtxOrGetCached(executor, hlo_module_config_));
      CompilePtxOrGetCached(executor, redzone_checker_ptx,
                            PtxCompilationOptions(hlo_module_config_)));

  se::ScopedDeviceMemory<uint64> out_param =
      executor->AllocateOwnedScalar<uint64>();
@@ -305,8 +275,7 @@ Status RedzoneAllocator::CheckRedzones(se::Stream* stream) const {

  auto typed_or = CreateTypedKernel<se::DeviceMemory<uint8>, uint8, uint64,
                                    se::DeviceMemory<uint64>>(
      "redzone_checker",
      /*num_args=*/4, redzone_checker_ptx, compiled_ptx, executor);
      "redzone_checker", redzone_checker_ptx, compiled_ptx, executor);

  // TF_ASSIGN_OR_RETURN does not work due to complex template.
  if (!typed_or.ok()) {

|
||||
return potential_cuda_roots;
|
||||
}
|
||||
|
||||
StatusOr<absl::Span<const uint8>> CompilePtxOrGetCached(
|
||||
se::StreamExecutor* executor, absl::string_view ptx,
|
||||
PtxCompilationOptions compilation_options) {
|
||||
using PtxCacheKey = std::tuple<se::StreamExecutor*, std::string,
|
||||
PtxCompilationOptions::PtxOptionsTuple>;
|
||||
static tensorflow::mutex ptx_cache_mutex(tensorflow::LINKER_INITIALIZED);
|
||||
static auto& ptx_cache GUARDED_BY(ptx_cache_mutex) =
|
||||
*new absl::flat_hash_map<PtxCacheKey, std::vector<uint8>>();
|
||||
|
||||
tensorflow::mutex_lock lock(ptx_cache_mutex);
|
||||
PtxCacheKey cache_key{executor, std::string(ptx),
|
||||
compilation_options.ToTuple()};
|
||||
auto it = ptx_cache.find(cache_key);
|
||||
if (it == ptx_cache.end()) {
|
||||
TF_ASSIGN_OR_RETURN(std::vector<uint8> compiled,
|
||||
CompilePtx(executor, ptx, compilation_options));
|
||||
it = ptx_cache.emplace(cache_key, std::move(compiled)).first;
|
||||
}
|
||||
|
||||
CHECK(it != ptx_cache.end());
|
||||
const std::vector<uint8>& compiled = it->second;
|
||||
return absl::MakeSpan(compiled);
|
||||
}
|
||||
|
||||
StatusOr<std::vector<uint8>> CompilePtx(
|
||||
se::StreamExecutor* stream_exec, absl::string_view ptx,
|
||||
PtxCompilationOptions compile_ptx_options) {
|
||||
|
||||
@ -57,9 +57,11 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
|
||||
// device while another thread is using it.
|
||||
tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec);
|
||||
|
||||
// Creates a type-safe kernel which can be launched with stream.ThenLaunch.
|
||||
// Creates a kernel which can be launched with stream.ThenLaunch, such that
|
||||
// the types of the arguments provided for launch would have to match
|
||||
// types of the arguments provided at creation time.
|
||||
//
|
||||
// The kernel has a provided name, and is based from provided PTX in ptx,
|
||||
// The kernel has a name kernel_name, and is based from provided PTX in ptx,
|
||||
// and (optional) compiled PTX in cubin_data.
|
||||
// The canonical storage for both ptx and cubin_data should outlive the
|
||||
// lifetime of the kernel.
|
||||
@ -67,9 +69,10 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec);
|
||||
// This is a preferred API since it provides type safety for kernel launches.
|
||||
template <typename... Args>
|
||||
StatusOr<std::unique_ptr<se::TypedKernel<Args...>>> CreateTypedKernel(
|
||||
absl::string_view kernel_name, uint64 num_args, absl::string_view ptx,
|
||||
absl::string_view kernel_name, absl::string_view ptx,
|
||||
absl::Span<const uint8> cubin_data, se::StreamExecutor* stream_exec) {
|
||||
se::MultiKernelLoaderSpec loader_spec(num_args);
|
||||
auto kernel_base = absl::make_unique<se::TypedKernel<Args...>>(stream_exec);
|
||||
se::MultiKernelLoaderSpec loader_spec(kernel_base->kNumberOfParameters);
|
||||
loader_spec.AddCudaPtxInMemory(ptx, kernel_name);
|
||||
|
||||
if (!cubin_data.empty()) {
|
||||
@ -77,7 +80,6 @@ StatusOr<std::unique_ptr<se::TypedKernel<Args...>>> CreateTypedKernel(
|
||||
reinterpret_cast<const char*>(cubin_data.data()), kernel_name);
|
||||
}
|
||||
|
||||
auto kernel_base = absl::make_unique<se::TypedKernel<Args...>>(stream_exec);
|
||||
if (!stream_exec->GetKernel(loader_spec, kernel_base.get())) {
|
||||
return InternalError("Unable to load kernel '%s'", kernel_name);
|
||||
}
|
||||
@ -135,6 +137,14 @@ StatusOr<std::vector<uint8>> CompilePtx(
|
||||
se::StreamExecutor* stream_exec, absl::string_view ptx,
|
||||
PtxCompilationOptions compile_ptx_options);
|
||||
|
||||
// Same as CompilePtx, but caches the result, and returns unowned view of
|
||||
// the compiled binary.
|
||||
//
|
||||
// A copy of the string provided in ptx will be made.
|
||||
StatusOr<absl::Span<const uint8>> CompilePtxOrGetCached(
|
||||
se::StreamExecutor* executor, absl::string_view ptx,
|
||||
PtxCompilationOptions compilation_options);
|
||||
|
||||
// Returns a vector of potential locations of the CUDA root directory.
|
||||
// Searches through tensorflow CUDA locations AND through the CUDA location
|
||||
// specified in compile_ptx_options (can be constructed from HloModuleConfig).
|
||||
|
||||
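Note: num_args is now derived from the template arguments via se::TypedKernel's kNumberOfParameters, so call sites can no longer pass a mismatched count. A sketch of creating and launching a kernel this way, modeled on the redzone checker above (the launch-dimension and argument values are placeholders):

TF_ASSIGN_OR_RETURN(
    auto kernel,
    (CreateTypedKernel<se::DeviceMemory<uint8>, uint8, uint64,
                       se::DeviceMemory<uint64>>(
        "redzone_checker", redzone_checker_ptx, compiled_ptx, executor)));
// The argument types below are checked against the template arguments above.
stream->ThenLaunch(se::ThreadDim(threads_per_block), se::BlockDim(block_count),
                   *kernel, buffer, pattern, buffer_length, out_param);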