Don't use BufferComparator if no autotuning is done.

If we have a cache hit, there is no need to allocate and initialize buffers
with RedzoneAllocator, or to set up a BufferComparator to verify results. So
we move this setup logic from RunOnInstruction into DoUncachedGemmAutotune.

PiperOrigin-RevId: 300707327
Change-Id: I4cda5733de08d16629f19fef8d3e99045c84af7a
Author: Adrian Kuegel (2020-03-13 00:49:57 -07:00), committed by TensorFlower Gardener
parent 5d1b37c014
commit 0f092a67a3

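The shape of the change is easiest to see in miniature. Below is a minimal,
self-contained C++ sketch of the caching pattern this commit applies, with
hypothetical names rather than the XLA API: all expensive setup lives in the
uncached path, so a cache hit pays only for the locked map lookup.

// Minimal sketch of the pattern this commit applies (hypothetical names,
// not the XLA API): expensive setup is deferred into the uncached path,
// so a cache hit pays only for the map lookup.
#include <map>
#include <mutex>
#include <optional>

using CacheKey = int;
using Algorithm = int;

std::mutex cache_mu;
std::map<CacheKey, std::optional<Algorithm>> autotune_cache;  // guarded by cache_mu

std::optional<Algorithm> AutotuneUncached(CacheKey key) {
  // Only here would buffers be allocated, initialized, and compared.
  return Algorithm{42};  // placeholder result
}

std::optional<Algorithm> Autotune(CacheKey key) {
  std::lock_guard<std::mutex> lock(cache_mu);
  auto it = autotune_cache.find(key);
  if (it != autotune_cache.end()) return it->second;        // hit: skip setup
  std::optional<Algorithm> result = AutotuneUncached(key);  // miss: heavy work
  autotune_cache.emplace(key, result);
  return result;
}

int main() { return Autotune(7).has_value() ? 0 : 1; }

In the real pass, DoGemmAutotune plays the role of Autotune and
DoUncachedGemmAutotune the role of AutotuneUncached.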

@@ -58,15 +58,48 @@ static int64 cache_misses TF_GUARDED_BY(autotune_cache_mu) = 0;
 // than sm_50 -- in both cases, cublas doesn't support gemm-with-algorithm at
 // all.
 static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
-    const HloInstruction* gemm, se::DeviceMemoryBase lhs_buffer,
-    se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer,
-    se::DeviceMemoryBase reference_result_buffer, se::Stream* stream,
-    const se::RedzoneAllocator& allocator, const BufferComparator& comparator,
-    bool crash_on_checking_failure) {
+    const HloInstruction* gemm, se::Stream* stream,
+    se::DeviceMemoryAllocator* allocator) {
   if (!stream->parent()->SynchronizeAllActivity()) {
     return InternalError("Failed to synchronize GPU for autotuning.");
   }
+  const HloModuleConfig& hlo_module_config = gemm->GetModule()->config();
+  const bool init_cublas_data =
+      hlo_module_config.debug_options().xla_gpu_autotune_level() > 1;
+  se::RedzoneAllocator input_output_allocator(
+      stream, allocator, PtxOptsFromConfig(hlo_module_config),
+      /*memory_limit=*/std::numeric_limits<int64>::max());
+  BufferComparator comparator(gemm->shape(), hlo_module_config);
+  int64 rng_state = 0;
+  auto get_initialized_buffer =
+      [&](const HloInstruction* op) -> StatusOr<se::DeviceMemoryBase> {
+    TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer,
+                        input_output_allocator.AllocateBytes(
+                            ShapeUtil::ByteSizeOf(op->shape())));
+    if (init_cublas_data) {
+      InitializeBuffer(stream, op->shape().element_type(), &rng_state, buffer);
+    }
+    return buffer;
+  };
+  TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase lhs_buffer,
+                      get_initialized_buffer(gemm->operand(0)));
+  TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase rhs_buffer,
+                      get_initialized_buffer(gemm->operand(1)));
+  TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase output_buffer,
+                      get_initialized_buffer(gemm));
+  TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase reference_result_buffer,
+                      get_initialized_buffer(gemm));
+  const DebugOptions& debug_options =
+      gemm->GetModule()->config().debug_options();
+  const bool crash_on_checking_failure =
+      debug_options.xla_gpu_crash_on_verification_failures();
   GemmBackendConfig backend_config =
       gemm->backend_config<GemmBackendConfig>().ValueOrDie();
   const int32 cublas_autotune_level =
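For context on the block that moved here, the following is a rough,
self-contained analogue of the get_initialized_buffer lambda, using standard
C++ stand-ins for the StreamExecutor types (RedzoneAllocator,
DeviceMemoryBase, and InitializeBuffer are the real names; the buffer type,
sizes, and RNG below are illustrative only):

// Rough analogue of get_initialized_buffer: one allocation plus an optional
// pseudo-random fill per operand. The fill is skipped entirely when the
// autotune level does not require data initialization.
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  std::uint64_t rng_state = 0;
  const bool init_cublas_data = true;  // stands in for autotune_level > 1

  auto get_initialized_buffer = [&](std::size_t byte_size) {
    std::vector<std::uint8_t> buffer(byte_size);
    if (init_cublas_data) {
      for (std::uint8_t& b : buffer) {
        // Cheap LCG fill (Knuth's MMIX constants), mimicking the shared
        // rng_state threaded through the real InitializeBuffer calls.
        rng_state = rng_state * 6364136223846793005ULL + 1442695040888963407ULL;
        b = static_cast<std::uint8_t>(rng_state >> 56);
      }
    }
    return buffer;
  };

  auto lhs = get_initialized_buffer(1024);
  auto rhs = get_initialized_buffer(1024);
  auto output = get_initialized_buffer(1024);
  auto reference_result = get_initialized_buffer(1024);  // for verification
  (void)lhs; (void)rhs; (void)output; (void)reference_result;
  return 0;
}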
@@ -124,7 +157,7 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
   TF_ASSIGN_OR_RETURN(
       se::RedzoneAllocator::RedzoneCheckStatus rz_check_status,
-      allocator.CheckRedzones());
+      input_output_allocator.CheckRedzones());
   if (!rz_check_status.ok()) {
     result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED);
     *result.mutable_failure()->mutable_msg() =
@@ -194,17 +227,14 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
 }
 static StatusOr<absl::optional<se::blas::AlgorithmType>> DoGemmAutotune(
-    const HloInstruction* instr, const HloInstruction* lhs,
-    const HloInstruction* rhs, se::DeviceMemoryBase lhs_buffer,
-    se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer,
-    se::DeviceMemoryBase reference_result_buffer, se::Stream* stream,
-    bool crash_on_checking_failure, const se::RedzoneAllocator& allocator,
-    const BufferComparator& comparator) {
+    const HloInstruction* instr, const GemmBackendConfig& gemm_config,
+    se::DeviceMemoryAllocator* allocator, se::Stream* stream) {
+  const HloInstruction* lhs = instr->operand(0);
+  const HloInstruction* rhs = instr->operand(1);
   // Don't run autotuning concurrently on the same GPU.
   tensorflow::mutex_lock gpu_lock = LockGpu(stream->parent());
-  GemmBackendConfig gemm_config =
-      instr->backend_config<GemmBackendConfig>().ValueOrDie();
   GemmCacheKey key =
       std::make_tuple(stream->parent(), lhs->shape(), rhs->shape(),
@@ -235,11 +265,8 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoGemmAutotune(
     VLOG(2) << "Batch size is non-singular, using generic algorithm";
     result = absl::nullopt;
   } else {
-    TF_ASSIGN_OR_RETURN(
-        result,
-        DoUncachedGemmAutotune(instr, lhs_buffer, rhs_buffer, output_buffer,
-                               reference_result_buffer, stream, allocator,
-                               comparator, crash_on_checking_failure));
+    TF_ASSIGN_OR_RETURN(result,
+                        DoUncachedGemmAutotune(instr, stream, allocator));
   }
   CHECK(autotune_cache.emplace(key, result).second);
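GemmCacheKey is built with std::make_tuple, which is convenient because
std::tuple supplies lexicographic ordering for free and can therefore key a
std::map directly. Here is a hedged sketch with simplified field types; the
visible fields in the hunk above are the stream executor and the operand
shapes, and the remaining tuple fields are truncated there.

// Sketch of the cache-key idea (simplified field types, not the real
// GemmCacheKey): std::tuple's built-in operator< lets it key a std::map.
#include <map>
#include <optional>
#include <string>
#include <tuple>

using CacheKey = std::tuple<const void*,   // device identity (stream executor)
                            std::string,   // lhs shape fingerprint
                            std::string>;  // rhs shape fingerprint

std::map<CacheKey, std::optional<int>> autotune_cache;

int main() {
  CacheKey key = std::make_tuple(nullptr, "f32[64,32]", "f32[32,16]");
  autotune_cache.emplace(key, 7);  // cache an algorithm id for this key
  return autotune_cache.count(key) == 1 ? 0 : 1;
}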
@@ -255,52 +282,11 @@ static StatusOr<bool> RunOnInstruction(HloInstruction* instr,
   TF_ASSIGN_OR_RETURN(se::Stream* const stream,
                       allocator->GetStream(executor->device_ordinal()));
-  const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
-  const bool init_cublas_data =
-      hlo_module_config.debug_options().xla_gpu_autotune_level() > 1;
-  se::RedzoneAllocator input_output_allocator(
-      stream, allocator, PtxOptsFromConfig(hlo_module_config),
-      /*memory_limit=*/std::numeric_limits<int64>::max());
-  BufferComparator comparator(instr->shape(), hlo_module_config);
-  int64 rng_state = 0;
-  auto get_initialized_buffer =
-      [&](const HloInstruction* op) -> StatusOr<se::DeviceMemoryBase> {
-    TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer,
-                        input_output_allocator.AllocateBytes(
-                            ShapeUtil::ByteSizeOf(op->shape())));
-    if (init_cublas_data) {
-      InitializeBuffer(stream, op->shape().element_type(), &rng_state, buffer);
-    }
-    return buffer;
-  };
   GemmBackendConfig gemm_config =
       instr->backend_config<GemmBackendConfig>().ValueOrDie();
-  const HloInstruction* lhs = instr->operand(0);
-  const HloInstruction* rhs = instr->operand(1);
-  TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase lhs_buffer,
-                      get_initialized_buffer(lhs));
-  TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase rhs_buffer,
-                      get_initialized_buffer(rhs));
-  TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase output_buffer,
-                      get_initialized_buffer(instr));
-  TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase reference_result_buffer,
-                      get_initialized_buffer(instr));
-  const DebugOptions& debug_options =
-      instr->GetModule()->config().debug_options();
-  const bool crash_on_checking_failure =
-      debug_options.xla_gpu_crash_on_verification_failures();
-  TF_ASSIGN_OR_RETURN(
-      absl::optional<se::blas::AlgorithmType> gemm_algorithm,
-      DoGemmAutotune(instr, lhs, rhs, lhs_buffer, rhs_buffer, output_buffer,
-                     reference_result_buffer, stream, crash_on_checking_failure,
-                     input_output_allocator, comparator));
+  TF_ASSIGN_OR_RETURN(absl::optional<se::blas::AlgorithmType> gemm_algorithm,
+                      DoGemmAutotune(instr, gemm_config, allocator, stream));
   // We update instruction->backend_config(); if no algorithms are supported,
   // a different API is used, which does not require specifying an algorithm.
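The trailing comment describes the follow-up step: the chosen algorithm, if
any, is written back into the instruction's backend config. A small
hypothetical sketch of that "optional algorithm" contract (not the real
GemmBackendConfig API):

// Sketch of the idea in the trailing comment (hypothetical proto-like type):
// the algorithm field is set only when autotuning produced one; otherwise it
// stays unset and the runtime uses the algorithm-less cuBLAS entry point.
#include <optional>

struct BackendConfig {
  std::optional<long> selected_algorithm;  // unset means "no algorithm chosen"
};

void UpdateConfig(BackendConfig& config, std::optional<long> gemm_algorithm) {
  if (gemm_algorithm.has_value()) {
    config.selected_algorithm = *gemm_algorithm;  // gemm-with-algorithm path
  }  // else: leave unset; a different API is used at runtime
}

int main() {
  BackendConfig config;
  UpdateConfig(config, std::nullopt);  // generic path keeps the field unset
  UpdateConfig(config, 13);            // autotuned path records algorithm 13
  return config.selected_algorithm == 13 ? 0 : 1;
}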