Don't use BufferComparator if no autotuning is done.
If we have a cache hit, there is no need to initialize a buffer with RedzoneAllocator to verify results. So we move this logic from RunOnInstruction to DoUncachedGemmAutotune. PiperOrigin-RevId: 300707327 Change-Id: I4cda5733de08d16629f19fef8d3e99045c84af7a
This commit is contained in:
parent
5d1b37c014
commit
0f092a67a3
@ -58,15 +58,48 @@ static int64 cache_misses TF_GUARDED_BY(autotune_cache_mu) = 0;
|
||||
// than sm_50 -- in both cases, cublas doesn't support gemm-with-algorithm at
|
||||
// all.
|
||||
static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
|
||||
const HloInstruction* gemm, se::DeviceMemoryBase lhs_buffer,
|
||||
se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer,
|
||||
se::DeviceMemoryBase reference_result_buffer, se::Stream* stream,
|
||||
const se::RedzoneAllocator& allocator, const BufferComparator& comparator,
|
||||
bool crash_on_checking_failure) {
|
||||
const HloInstruction* gemm, se::Stream* stream,
|
||||
se::DeviceMemoryAllocator* allocator) {
|
||||
if (!stream->parent()->SynchronizeAllActivity()) {
|
||||
return InternalError("Failed to synchronize GPU for autotuning.");
|
||||
}
|
||||
|
||||
const HloModuleConfig& hlo_module_config = gemm->GetModule()->config();
|
||||
const bool init_cublas_data =
|
||||
hlo_module_config.debug_options().xla_gpu_autotune_level() > 1;
|
||||
se::RedzoneAllocator input_output_allocator(
|
||||
stream, allocator, PtxOptsFromConfig(hlo_module_config),
|
||||
/*memory_limit=*/std::numeric_limits<int64>::max());
|
||||
|
||||
BufferComparator comparator(gemm->shape(), hlo_module_config);
|
||||
|
||||
int64 rng_state = 0;
|
||||
auto get_initialized_buffer =
|
||||
[&](const HloInstruction* op) -> StatusOr<se::DeviceMemoryBase> {
|
||||
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer,
|
||||
input_output_allocator.AllocateBytes(
|
||||
ShapeUtil::ByteSizeOf(op->shape())));
|
||||
if (init_cublas_data) {
|
||||
InitializeBuffer(stream, op->shape().element_type(), &rng_state, buffer);
|
||||
}
|
||||
return buffer;
|
||||
};
|
||||
|
||||
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase lhs_buffer,
|
||||
get_initialized_buffer(gemm->operand(0)));
|
||||
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase rhs_buffer,
|
||||
get_initialized_buffer(gemm->operand(1)));
|
||||
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase output_buffer,
|
||||
get_initialized_buffer(gemm));
|
||||
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase reference_result_buffer,
|
||||
get_initialized_buffer(gemm));
|
||||
|
||||
const DebugOptions& debug_options =
|
||||
gemm->GetModule()->config().debug_options();
|
||||
|
||||
const bool crash_on_checking_failure =
|
||||
debug_options.xla_gpu_crash_on_verification_failures();
|
||||
|
||||
GemmBackendConfig backend_config =
|
||||
gemm->backend_config<GemmBackendConfig>().ValueOrDie();
|
||||
const int32 cublas_autotune_level =
|
||||
@ -124,7 +157,7 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
se::RedzoneAllocator::RedzoneCheckStatus rz_check_status,
|
||||
allocator.CheckRedzones());
|
||||
input_output_allocator.CheckRedzones());
|
||||
if (!rz_check_status.ok()) {
|
||||
result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED);
|
||||
*result.mutable_failure()->mutable_msg() =
|
||||
@ -194,17 +227,14 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
|
||||
}
|
||||
|
||||
static StatusOr<absl::optional<se::blas::AlgorithmType>> DoGemmAutotune(
|
||||
const HloInstruction* instr, const HloInstruction* lhs,
|
||||
const HloInstruction* rhs, se::DeviceMemoryBase lhs_buffer,
|
||||
se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer,
|
||||
se::DeviceMemoryBase reference_result_buffer, se::Stream* stream,
|
||||
bool crash_on_checking_failure, const se::RedzoneAllocator& allocator,
|
||||
const BufferComparator& comparator) {
|
||||
const HloInstruction* instr, const GemmBackendConfig& gemm_config,
|
||||
se::DeviceMemoryAllocator* allocator, se::Stream* stream) {
|
||||
const HloInstruction* lhs = instr->operand(0);
|
||||
const HloInstruction* rhs = instr->operand(1);
|
||||
|
||||
// Don't run autotuning concurrently on the same GPU.
|
||||
tensorflow::mutex_lock gpu_lock = LockGpu(stream->parent());
|
||||
|
||||
GemmBackendConfig gemm_config =
|
||||
instr->backend_config<GemmBackendConfig>().ValueOrDie();
|
||||
|
||||
GemmCacheKey key =
|
||||
std::make_tuple(stream->parent(), lhs->shape(), rhs->shape(),
|
||||
@ -235,11 +265,8 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoGemmAutotune(
|
||||
VLOG(2) << "Batch size is non-singular, using generic algorithm";
|
||||
result = absl::nullopt;
|
||||
} else {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
result,
|
||||
DoUncachedGemmAutotune(instr, lhs_buffer, rhs_buffer, output_buffer,
|
||||
reference_result_buffer, stream, allocator,
|
||||
comparator, crash_on_checking_failure));
|
||||
TF_ASSIGN_OR_RETURN(result,
|
||||
DoUncachedGemmAutotune(instr, stream, allocator));
|
||||
}
|
||||
|
||||
CHECK(autotune_cache.emplace(key, result).second);
|
||||
@ -255,52 +282,11 @@ static StatusOr<bool> RunOnInstruction(HloInstruction* instr,
|
||||
TF_ASSIGN_OR_RETURN(se::Stream* const stream,
|
||||
allocator->GetStream(executor->device_ordinal()));
|
||||
|
||||
const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
|
||||
const bool init_cublas_data =
|
||||
hlo_module_config.debug_options().xla_gpu_autotune_level() > 1;
|
||||
se::RedzoneAllocator input_output_allocator(
|
||||
stream, allocator, PtxOptsFromConfig(hlo_module_config),
|
||||
/*memory_limit=*/std::numeric_limits<int64>::max());
|
||||
|
||||
BufferComparator comparator(instr->shape(), hlo_module_config);
|
||||
|
||||
int64 rng_state = 0;
|
||||
auto get_initialized_buffer =
|
||||
[&](const HloInstruction* op) -> StatusOr<se::DeviceMemoryBase> {
|
||||
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer,
|
||||
input_output_allocator.AllocateBytes(
|
||||
ShapeUtil::ByteSizeOf(op->shape())));
|
||||
if (init_cublas_data) {
|
||||
InitializeBuffer(stream, op->shape().element_type(), &rng_state, buffer);
|
||||
}
|
||||
return buffer;
|
||||
};
|
||||
|
||||
GemmBackendConfig gemm_config =
|
||||
instr->backend_config<GemmBackendConfig>().ValueOrDie();
|
||||
const HloInstruction* lhs = instr->operand(0);
|
||||
const HloInstruction* rhs = instr->operand(1);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase lhs_buffer,
|
||||
get_initialized_buffer(lhs));
|
||||
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase rhs_buffer,
|
||||
get_initialized_buffer(rhs));
|
||||
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase output_buffer,
|
||||
get_initialized_buffer(instr));
|
||||
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase reference_result_buffer,
|
||||
get_initialized_buffer(instr));
|
||||
|
||||
const DebugOptions& debug_options =
|
||||
instr->GetModule()->config().debug_options();
|
||||
|
||||
const bool crash_on_checking_failure =
|
||||
debug_options.xla_gpu_crash_on_verification_failures();
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
absl::optional<se::blas::AlgorithmType> gemm_algorithm,
|
||||
DoGemmAutotune(instr, lhs, rhs, lhs_buffer, rhs_buffer, output_buffer,
|
||||
reference_result_buffer, stream, crash_on_checking_failure,
|
||||
input_output_allocator, comparator));
|
||||
TF_ASSIGN_OR_RETURN(absl::optional<se::blas::AlgorithmType> gemm_algorithm,
|
||||
DoGemmAutotune(instr, gemm_config, allocator, stream));
|
||||
|
||||
// We update instruction->backend_config(); if no algorithms are supported,
|
||||
// a different API is used, which does not require specifying an algorithm.
|
||||
|
Loading…
Reference in New Issue
Block a user