From fb9026a766978fee97ce7b3ef53e0774f8b57537 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Wed, 22 Apr 2020 20:16:26 -0700 Subject: [PATCH] Have XLA CPU emit TraceMe calls by default in JIT mode (roll forward) This lets Xprof enable these `TraceMe`s and produce timelines for XLA CPU. Rolling forward with bugfix (earlier we would try to create a `void*` type in LLVM IR, even though LLVM does not allow pointers to void). PiperOrigin-RevId: 307961400 Change-Id: I0624c32294387a113867c2d80de8cccfd6cd6c21 --- tensorflow/compiler/aot/tfcompile.bzl | 11 ++++++++++- tensorflow/compiler/xla/debug_options_flags.cc | 9 ++++++++- tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 14 +++++++------- .../compiler/xla/service/hlo_module_config.h | 4 ++++ tensorflow/compiler/xla/xla.proto | 5 ++++- 5 files changed, 33 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 35a054a1aab..abccefbcdbb 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -38,6 +38,7 @@ def tf_library( tfcompile_tool = "//tensorflow/compiler/aot:tfcompile", include_standard_runtime_deps = True, enable_xla_hlo_profiling = False, + enable_tracemes = False, mlir_components = "None", deps = None, tags = []): @@ -89,6 +90,9 @@ def tf_library( enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated program, and emit metadata that lets us pretty-print the gathered profile counters. + enable_tracemes: Tell tfcompile to generate calls to + TraceMe::Activity{Start|End} around HLO instructions that can be used by + Xprof to construct profiler timelines. mlir_components: When the value is "None", no components use MLIR. When the value is "Bridge", use MLIR to translate GraphDef to HLO. 
deps: a list of deps to include on the build rules for the generated @@ -190,6 +194,11 @@ def tf_library( else: profiling_flag = "" + if enable_tracemes: + traceme_flag = "--xla_cpu_enable_xprof_traceme=true" + else: + traceme_flag = "--xla_cpu_enable_xprof_traceme=false" + mlir_flag = "--mlir_components=" + mlir_components srcs = [tfcompile_graph, config] @@ -218,7 +227,7 @@ def tf_library( " --out_header=$(@D)/" + header_file + " --out_metadata_object=$(@D)/" + metadata_object_file + " --out_function_object=$(@D)/" + function_object_file + - " " + flags + " " + profiling_flag + " " + mlir_flag + " " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag ), tools = [tfcompile_tool], visibility = visibility, diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 8604531889e..e6d60e51e75 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -63,6 +63,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_allow_excess_precision(true); opts.set_xla_force_host_platform_device_count(1); opts.set_xla_gpu_deterministic_reductions(false); + opts.set_xla_cpu_enable_xprof_traceme(true); + return opts; } @@ -529,7 +531,6 @@ static void AllocateFlags() { flag_values->xla_gpu_algorithm_blacklist_path(), "An AlgorithmBlacklist text proto file as a blacklist " "of convolutions to avoid to use."), - tensorflow::Flag( "xla_gpu_deterministic_reductions", bool_setter_for(&DebugOptions::set_xla_gpu_deterministic_reductions), @@ -545,6 +546,12 @@ static void AllocateFlags() { bool_setter_for(&DebugOptions::set_xla_tpu_detect_inf), flag_values->xla_tpu_detect_inf(), "Trigger error on execution on TPU if a INF value is detected"), + tensorflow::Flag( + "xla_cpu_enable_xprof_traceme", + bool_setter_for(&DebugOptions::set_xla_cpu_enable_xprof_traceme), + flag_values->xla_cpu_enable_xprof_traceme(), + "If true, XLA CPU generates code to call " 
+ "TraceMe::Activity{Start|End} around HLO operations."), }); ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects); } diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index f4549ac9f3b..c19fa779b60 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -182,11 +182,8 @@ StatusOr IrEmitter::EmitComputation( arch_type_ == llvm::Triple::ArchType::x86_64; profiling_state_ = ProfilingState(use_rdtscp); - bool emit_tracing = - hlo_module_config_.hlo_profiling_enabled() && - hlo_module_config_.debug_options().xla_backend_extra_options().count( - "xla_hlo_trace"); - tracing_state_.set_enabled(emit_tracing); + tracing_state_.set_enabled( + computation->parent()->config().cpu_traceme_enabled()); TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, instruction_order)); llvm::Function* ir_function = compute_function_->function(); @@ -3126,7 +3123,8 @@ void IrEmitter::TracingState::EmitTracingStart(llvm::IRBuilder<>* b, } llvm::Type* int8_ptr_type = b->getInt8Ty()->getPointerTo(); - llvm::Type* void_ptr_type = b->getVoidTy()->getPointerTo(); + llvm::Type* void_ptr_type = + int8_ptr_type; // LLVM does not have a void*, we use an int8* instead. llvm::FunctionType* fn_type = llvm::FunctionType::get(b->getInt64Ty(), {void_ptr_type, int8_ptr_type}, /*isVarArg=*/false); @@ -3156,7 +3154,9 @@ void IrEmitter::TracingState::EmitTracingEnd(llvm::IRBuilder<>* b, return; } - llvm::Type* void_ptr_type = b->getVoidTy()->getPointerTo(); + llvm::Type* void_ptr_type = + b->getInt8Ty()->getPointerTo(); // LLVM does not have a void*, we use an + // int8* instead. 
llvm::FunctionType* fn_type = llvm::FunctionType::get(b->getVoidTy(), {void_ptr_type, b->getInt64Ty()}, /*isVarArg=*/false); diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 5b80a6adca2..b31a9ae6ca5 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -104,6 +104,10 @@ class HloModuleConfig { return debug_options_.xla_hlo_profile(); } + bool cpu_traceme_enabled() const { + return debug_options_.xla_cpu_enable_xprof_traceme(); + } + // Sets/returns the module seed set during execution. void set_seed(uint64 seed) { seed_ = seed; } uint64 seed() const { return seed_; } diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index c8ba08fc351..826876ed9cb 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -269,7 +269,10 @@ message DebugOptions { bool xla_tpu_detect_nan = 135; bool xla_tpu_detect_inf = 136; - // Next id: 137 + // True if TraceMe annotations are enabled for XLA:CPU. + bool xla_cpu_enable_xprof_traceme = 137; + + // Next id: 138 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend.