Merge pull request #35621 from bas-aarts:bas_auto_tune_level
PiperOrigin-RevId: 291051983 Change-Id: I9518cca1eb07e44be7857b0c3d87c94de420fdce
This commit is contained in:
commit
a836923de3
@ -34,6 +34,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
|
||||
opts.set_xla_llvm_enable_invariant_load_metadata(true);
|
||||
opts.set_xla_llvm_disable_expensive_passes(false);
|
||||
opts.set_xla_backend_optimization_level(3);
|
||||
opts.set_xla_gpu_autotune_level(4);
|
||||
opts.set_xla_cpu_multi_thread_eigen(true);
|
||||
opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib");
|
||||
opts.set_xla_eliminate_hlo_implicit_broadcast(true);
|
||||
@ -399,10 +400,12 @@ static void AllocateFlags() {
|
||||
"Crashes the program on extra verification failures, e.g. cuDNN "
|
||||
"cross checking failures"),
|
||||
tensorflow::Flag(
|
||||
"xla_gpu_disable_autotune",
|
||||
bool_setter_for(&DebugOptions::set_xla_gpu_disable_autotune),
|
||||
flag_values->xla_gpu_disable_autotune(),
|
||||
"Disable GEMM and Convolution auto-tuning."),
|
||||
"xla_gpu_autotune_level",
|
||||
int32_setter_for(&DebugOptions::set_xla_gpu_autotune_level),
|
||||
flag_values->xla_gpu_autotune_level(),
|
||||
"Set GEMM and Convolution auto-tuning level."
|
||||
"0 = off; 1 = on; 2 = on+init; 3 = on+init+reinit; 4 = "
|
||||
"on+init+reinit+check."),
|
||||
tensorflow::Flag(
|
||||
"xla_force_host_platform_device_count",
|
||||
int32_setter_for(
|
||||
|
@ -69,6 +69,10 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
|
||||
|
||||
GemmBackendConfig backend_config =
|
||||
gemm->backend_config<GemmBackendConfig>().ValueOrDie();
|
||||
const int32 cublas_autotune_level =
|
||||
gemm->GetModule()->config().debug_options().xla_gpu_autotune_level();
|
||||
const bool reinit_cublas_data = cublas_autotune_level > 2;
|
||||
const bool check_cublas = cublas_autotune_level > 3;
|
||||
|
||||
VLOG(3) << "Starting autotune of GemmThunk " << gemm->ToString();
|
||||
|
||||
@ -81,7 +85,7 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
|
||||
for (se::blas::AlgorithmType algorithm : algorithms) {
|
||||
// Make sure the output buffer always has the same value if we use
|
||||
// the bias parameter.
|
||||
if (backend_config.beta() != 0) {
|
||||
if (reinit_cublas_data && backend_config.beta() != 0) {
|
||||
int64 rng_state = 0;
|
||||
InitializeBuffer(stream, gemm->shape().element_type(), &rng_state,
|
||||
output_buffer);
|
||||
@ -114,6 +118,10 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
|
||||
*result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
|
||||
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
|
||||
|
||||
if (!check_cublas) {
|
||||
continue;
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
se::RedzoneAllocator::RedzoneCheckStatus rz_check_status,
|
||||
allocator.CheckRedzones());
|
||||
@ -248,6 +256,8 @@ static StatusOr<bool> RunOnInstruction(HloInstruction* instr,
|
||||
allocator->GetStream(executor->device_ordinal()));
|
||||
|
||||
const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
|
||||
const bool init_cublas_data =
|
||||
hlo_module_config.debug_options().xla_gpu_autotune_level() > 1;
|
||||
se::RedzoneAllocator input_output_allocator(
|
||||
stream, allocator, PtxOptsFromConfig(hlo_module_config),
|
||||
/*memory_limit=*/std::numeric_limits<int64>::max());
|
||||
@ -260,7 +270,9 @@ static StatusOr<bool> RunOnInstruction(HloInstruction* instr,
|
||||
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer,
|
||||
input_output_allocator.AllocateBytes(
|
||||
ShapeUtil::ByteSizeOf(op->shape())));
|
||||
InitializeBuffer(stream, op->shape().element_type(), &rng_state, buffer);
|
||||
if (init_cublas_data) {
|
||||
InitializeBuffer(stream, op->shape().element_type(), &rng_state, buffer);
|
||||
}
|
||||
return buffer;
|
||||
};
|
||||
|
||||
@ -316,7 +328,7 @@ static StatusOr<bool> RunOnComputation(HloComputation* computation,
|
||||
StatusOr<bool> GemmAlgorithmPicker::Run(HloModule* module) {
|
||||
XLA_SCOPED_LOGGING_TIMER("GemmAlgorithmPicker");
|
||||
|
||||
if (module->config().debug_options().xla_gpu_disable_autotune()) {
|
||||
if (module->config().debug_options().xla_gpu_autotune_level() == 0) {
|
||||
VLOG(2) << "GEMM auto-tuning disabled, GemmAlgorithmPicker returning early";
|
||||
return false;
|
||||
}
|
||||
|
@ -343,14 +343,19 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
|
||||
const Shape& result_shape = instr->shape().tuple_shapes(0);
|
||||
int64 rng_state = 0;
|
||||
|
||||
const auto initialize_buffer = [&stream, &rng_state](
|
||||
const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
|
||||
const int32 conv_autotune_level =
|
||||
hlo_module_config.debug_options().xla_gpu_autotune_level();
|
||||
const bool init_conv_data = conv_autotune_level > 1;
|
||||
const bool check_conv = conv_autotune_level > 3;
|
||||
const auto initialize_buffer = [init_conv_data, &stream, &rng_state](
|
||||
DeviceMemoryBase buffer,
|
||||
const Shape& buffer_shape) {
|
||||
InitializeBuffer(stream, buffer_shape.element_type(), &rng_state, buffer);
|
||||
if (init_conv_data) {
|
||||
InitializeBuffer(stream, buffer_shape.element_type(), &rng_state, buffer);
|
||||
}
|
||||
};
|
||||
|
||||
const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
|
||||
|
||||
// Allocate space for the input, filter, and output of the convolution.
|
||||
se::RedzoneAllocator input_output_allocator(
|
||||
stream, allocator, PtxOptsFromConfig(hlo_module_config));
|
||||
@ -444,6 +449,10 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
|
||||
*result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
|
||||
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
|
||||
|
||||
if (!check_conv) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check for writes to redzones.
|
||||
TF_ASSIGN_OR_RETURN(bool input_output_allocator_redzone_clear,
|
||||
CheckRedzones(input_output_allocator, stream,
|
||||
@ -780,7 +789,7 @@ StatusOr<bool> GpuConvAlgorithmPicker::RunOnComputation(
|
||||
StatusOr<bool> GpuConvAlgorithmPicker::Run(HloModule* module) {
|
||||
XLA_SCOPED_LOGGING_TIMER("GpuConvAlgorithmPicker");
|
||||
|
||||
if (module->config().debug_options().xla_gpu_disable_autotune()) {
|
||||
if (module->config().debug_options().xla_gpu_autotune_level() == 0) {
|
||||
VLOG(2) << "Convolution auto-tuning disabled, GpuConvAlgorithmPicker "
|
||||
"returning early.";
|
||||
return false;
|
||||
|
@ -972,11 +972,11 @@ xla_test(
|
||||
)
|
||||
|
||||
# Run dot tests with auto-tuning disabled. This just does a basic sanity check
|
||||
# that enabling xla_gpu_disable_autotune does not break simple graphs.
|
||||
# that setting xla_gpu_autotune_level to 0 does not break simple graphs.
|
||||
xla_test(
|
||||
name = "dot_operation_test_autotune_disabled",
|
||||
srcs = ["dot_operation_test.cc"],
|
||||
args = ["--xla_gpu_disable_autotune"],
|
||||
args = ["--xla_gpu_autotune_level=0"],
|
||||
backends = ["gpu"],
|
||||
shard_count = 20,
|
||||
tags = [
|
||||
@ -1151,13 +1151,13 @@ xla_test(
|
||||
)
|
||||
|
||||
# Run convolution tests with auto-tuning disabled. This just does a basic
|
||||
# sanity check that enabling xla_gpu_disable_autotune does not break simple
|
||||
# sanity check that setting xla_gpu_autotune_level to 0 does not break simple
|
||||
# graphs.
|
||||
xla_test(
|
||||
name = "convolution_test_autotune_disabled",
|
||||
timeout = "long",
|
||||
srcs = ["convolution_test.cc"],
|
||||
args = ["--xla_gpu_disable_autotune"],
|
||||
args = ["--xla_gpu_autotune_level=0"],
|
||||
backends = ["gpu"],
|
||||
shard_count = 40,
|
||||
tags = [
|
||||
|
@ -158,7 +158,7 @@ message DebugOptions {
|
||||
bool xla_gpu_crash_on_verification_failures = 101;
|
||||
|
||||
// Disable GEMM and Convolution auto-tuning.
|
||||
bool xla_gpu_disable_autotune = 123;
|
||||
int32 xla_gpu_autotune_level = 123;
|
||||
|
||||
// Force the host platform to pretend that there are these many host
|
||||
// "devices". All these devices are backed by the same threadpool. Defaults
|
||||
|
@ -1646,9 +1646,9 @@ def disable_cudnn_autotune(func):
|
||||
original_tf_cudnn_use_autotune = os.environ.get("TF_CUDNN_USE_AUTOTUNE")
|
||||
os.environ["TF_CUDNN_USE_AUTOTUNE"] = "false"
|
||||
original_xla_flags = os.environ.get("XLA_FLAGS")
|
||||
new_xla_flags = "--xla_gpu_disable_autotune"
|
||||
new_xla_flags = "--xla_gpu_autotune_level=0"
|
||||
if original_xla_flags:
|
||||
new_xla_flags += " " + original_xla_flags
|
||||
new_xla_flags = original_xla_flags + " " + new_xla_flags
|
||||
os.environ["XLA_FLAGS"] = new_xla_flags
|
||||
|
||||
result = f(self, *args, **kwargs)
|
||||
|
Loading…
Reference in New Issue
Block a user