Don't use BufferComparator if no autotuning is done.
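On a cache hit there is no need to allocate and initialize buffers with RedzoneAllocator, or to construct a BufferComparator, just to verify results. This change therefore moves that setup logic out of RunOnInstruction and into DoUncachedGemmAutotune, so it is performed only when autotuning actually runs.

PiperOrigin-RevId: 300707327
Change-Id: I4cda5733de08d16629f19fef8d3e99045c84af7a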
This commit is contained in:
parent
5d1b37c014
commit
0f092a67a3
@ -58,15 +58,48 @@ static int64 cache_misses TF_GUARDED_BY(autotune_cache_mu) = 0;
|
|||||||
// than sm_50 -- in both cases, cublas doesn't support gemm-with-algorithm at
|
// than sm_50 -- in both cases, cublas doesn't support gemm-with-algorithm at
|
||||||
// all.
|
// all.
|
||||||
static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
|
static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
|
||||||
const HloInstruction* gemm, se::DeviceMemoryBase lhs_buffer,
|
const HloInstruction* gemm, se::Stream* stream,
|
||||||
se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer,
|
se::DeviceMemoryAllocator* allocator) {
|
||||||
se::DeviceMemoryBase reference_result_buffer, se::Stream* stream,
|
|
||||||
const se::RedzoneAllocator& allocator, const BufferComparator& comparator,
|
|
||||||
bool crash_on_checking_failure) {
|
|
||||||
if (!stream->parent()->SynchronizeAllActivity()) {
|
if (!stream->parent()->SynchronizeAllActivity()) {
|
||||||
return InternalError("Failed to synchronize GPU for autotuning.");
|
return InternalError("Failed to synchronize GPU for autotuning.");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const HloModuleConfig& hlo_module_config = gemm->GetModule()->config();
|
||||||
|
const bool init_cublas_data =
|
||||||
|
hlo_module_config.debug_options().xla_gpu_autotune_level() > 1;
|
||||||
|
se::RedzoneAllocator input_output_allocator(
|
||||||
|
stream, allocator, PtxOptsFromConfig(hlo_module_config),
|
||||||
|
/*memory_limit=*/std::numeric_limits<int64>::max());
|
||||||
|
|
||||||
|
BufferComparator comparator(gemm->shape(), hlo_module_config);
|
||||||
|
|
||||||
|
int64 rng_state = 0;
|
||||||
|
auto get_initialized_buffer =
|
||||||
|
[&](const HloInstruction* op) -> StatusOr<se::DeviceMemoryBase> {
|
||||||
|
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer,
|
||||||
|
input_output_allocator.AllocateBytes(
|
||||||
|
ShapeUtil::ByteSizeOf(op->shape())));
|
||||||
|
if (init_cublas_data) {
|
||||||
|
InitializeBuffer(stream, op->shape().element_type(), &rng_state, buffer);
|
||||||
|
}
|
||||||
|
return buffer;
|
||||||
|
};
|
||||||
|
|
||||||
|
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase lhs_buffer,
|
||||||
|
get_initialized_buffer(gemm->operand(0)));
|
||||||
|
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase rhs_buffer,
|
||||||
|
get_initialized_buffer(gemm->operand(1)));
|
||||||
|
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase output_buffer,
|
||||||
|
get_initialized_buffer(gemm));
|
||||||
|
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase reference_result_buffer,
|
||||||
|
get_initialized_buffer(gemm));
|
||||||
|
|
||||||
|
const DebugOptions& debug_options =
|
||||||
|
gemm->GetModule()->config().debug_options();
|
||||||
|
|
||||||
|
const bool crash_on_checking_failure =
|
||||||
|
debug_options.xla_gpu_crash_on_verification_failures();
|
||||||
|
|
||||||
GemmBackendConfig backend_config =
|
GemmBackendConfig backend_config =
|
||||||
gemm->backend_config<GemmBackendConfig>().ValueOrDie();
|
gemm->backend_config<GemmBackendConfig>().ValueOrDie();
|
||||||
const int32 cublas_autotune_level =
|
const int32 cublas_autotune_level =
|
||||||
@ -124,7 +157,7 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
|
|||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
se::RedzoneAllocator::RedzoneCheckStatus rz_check_status,
|
se::RedzoneAllocator::RedzoneCheckStatus rz_check_status,
|
||||||
allocator.CheckRedzones());
|
input_output_allocator.CheckRedzones());
|
||||||
if (!rz_check_status.ok()) {
|
if (!rz_check_status.ok()) {
|
||||||
result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED);
|
result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED);
|
||||||
*result.mutable_failure()->mutable_msg() =
|
*result.mutable_failure()->mutable_msg() =
|
||||||
@ -194,17 +227,14 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
|
|||||||
}
|
}
|
||||||
|
|
||||||
static StatusOr<absl::optional<se::blas::AlgorithmType>> DoGemmAutotune(
|
static StatusOr<absl::optional<se::blas::AlgorithmType>> DoGemmAutotune(
|
||||||
const HloInstruction* instr, const HloInstruction* lhs,
|
const HloInstruction* instr, const GemmBackendConfig& gemm_config,
|
||||||
const HloInstruction* rhs, se::DeviceMemoryBase lhs_buffer,
|
se::DeviceMemoryAllocator* allocator, se::Stream* stream) {
|
||||||
se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer,
|
const HloInstruction* lhs = instr->operand(0);
|
||||||
se::DeviceMemoryBase reference_result_buffer, se::Stream* stream,
|
const HloInstruction* rhs = instr->operand(1);
|
||||||
bool crash_on_checking_failure, const se::RedzoneAllocator& allocator,
|
|
||||||
const BufferComparator& comparator) {
|
|
||||||
// Don't run autotuning concurrently on the same GPU.
|
// Don't run autotuning concurrently on the same GPU.
|
||||||
tensorflow::mutex_lock gpu_lock = LockGpu(stream->parent());
|
tensorflow::mutex_lock gpu_lock = LockGpu(stream->parent());
|
||||||
|
|
||||||
GemmBackendConfig gemm_config =
|
|
||||||
instr->backend_config<GemmBackendConfig>().ValueOrDie();
|
|
||||||
|
|
||||||
GemmCacheKey key =
|
GemmCacheKey key =
|
||||||
std::make_tuple(stream->parent(), lhs->shape(), rhs->shape(),
|
std::make_tuple(stream->parent(), lhs->shape(), rhs->shape(),
|
||||||
@ -235,11 +265,8 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoGemmAutotune(
|
|||||||
VLOG(2) << "Batch size is non-singular, using generic algorithm";
|
VLOG(2) << "Batch size is non-singular, using generic algorithm";
|
||||||
result = absl::nullopt;
|
result = absl::nullopt;
|
||||||
} else {
|
} else {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(result,
|
||||||
result,
|
DoUncachedGemmAutotune(instr, stream, allocator));
|
||||||
DoUncachedGemmAutotune(instr, lhs_buffer, rhs_buffer, output_buffer,
|
|
||||||
reference_result_buffer, stream, allocator,
|
|
||||||
comparator, crash_on_checking_failure));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
CHECK(autotune_cache.emplace(key, result).second);
|
CHECK(autotune_cache.emplace(key, result).second);
|
||||||
@ -255,52 +282,11 @@ static StatusOr<bool> RunOnInstruction(HloInstruction* instr,
|
|||||||
TF_ASSIGN_OR_RETURN(se::Stream* const stream,
|
TF_ASSIGN_OR_RETURN(se::Stream* const stream,
|
||||||
allocator->GetStream(executor->device_ordinal()));
|
allocator->GetStream(executor->device_ordinal()));
|
||||||
|
|
||||||
const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
|
|
||||||
const bool init_cublas_data =
|
|
||||||
hlo_module_config.debug_options().xla_gpu_autotune_level() > 1;
|
|
||||||
se::RedzoneAllocator input_output_allocator(
|
|
||||||
stream, allocator, PtxOptsFromConfig(hlo_module_config),
|
|
||||||
/*memory_limit=*/std::numeric_limits<int64>::max());
|
|
||||||
|
|
||||||
BufferComparator comparator(instr->shape(), hlo_module_config);
|
|
||||||
|
|
||||||
int64 rng_state = 0;
|
|
||||||
auto get_initialized_buffer =
|
|
||||||
[&](const HloInstruction* op) -> StatusOr<se::DeviceMemoryBase> {
|
|
||||||
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer,
|
|
||||||
input_output_allocator.AllocateBytes(
|
|
||||||
ShapeUtil::ByteSizeOf(op->shape())));
|
|
||||||
if (init_cublas_data) {
|
|
||||||
InitializeBuffer(stream, op->shape().element_type(), &rng_state, buffer);
|
|
||||||
}
|
|
||||||
return buffer;
|
|
||||||
};
|
|
||||||
|
|
||||||
GemmBackendConfig gemm_config =
|
GemmBackendConfig gemm_config =
|
||||||
instr->backend_config<GemmBackendConfig>().ValueOrDie();
|
instr->backend_config<GemmBackendConfig>().ValueOrDie();
|
||||||
const HloInstruction* lhs = instr->operand(0);
|
|
||||||
const HloInstruction* rhs = instr->operand(1);
|
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase lhs_buffer,
|
TF_ASSIGN_OR_RETURN(absl::optional<se::blas::AlgorithmType> gemm_algorithm,
|
||||||
get_initialized_buffer(lhs));
|
DoGemmAutotune(instr, gemm_config, allocator, stream));
|
||||||
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase rhs_buffer,
|
|
||||||
get_initialized_buffer(rhs));
|
|
||||||
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase output_buffer,
|
|
||||||
get_initialized_buffer(instr));
|
|
||||||
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase reference_result_buffer,
|
|
||||||
get_initialized_buffer(instr));
|
|
||||||
|
|
||||||
const DebugOptions& debug_options =
|
|
||||||
instr->GetModule()->config().debug_options();
|
|
||||||
|
|
||||||
const bool crash_on_checking_failure =
|
|
||||||
debug_options.xla_gpu_crash_on_verification_failures();
|
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
|
||||||
absl::optional<se::blas::AlgorithmType> gemm_algorithm,
|
|
||||||
DoGemmAutotune(instr, lhs, rhs, lhs_buffer, rhs_buffer, output_buffer,
|
|
||||||
reference_result_buffer, stream, crash_on_checking_failure,
|
|
||||||
input_output_allocator, comparator));
|
|
||||||
|
|
||||||
// We update instruction->backend_config(); if no algorithms are supported,
|
// We update instruction->backend_config(); if no algorithms are supported,
|
||||||
// a different API is used, which does not require specifying an algorithm.
|
// a different API is used, which does not require specifying an algorithm.
|
||||||
|
Loading…
Reference in New Issue
Block a user