From 29ebd644a2656bb37be716c5b8cd68b9c0c596be Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Thu, 10 Dec 2020 16:44:59 -0800 Subject: [PATCH] [TF2XLA] Display Python location information in XLA metadata As a heuristic, we show the last frame which does not have "tensorflow/python" substring. PiperOrigin-RevId: 346892818 Change-Id: I72b5953ba5de9c0a10648805a852fcc033dd960b --- .../xla/transforms/legalize_tf_with_tf2xla.cc | 3 +- .../compiler/tf2xla/xla_compilation_device.cc | 19 +++++++++++- tensorflow/compiler/tf2xla/xla_compiler.cc | 2 +- tensorflow/compiler/tf2xla/xla_context.cc | 11 +++++-- tensorflow/compiler/tf2xla/xla_context.h | 15 +++++++++- .../common_runtime/inline_function_utils.cc | 1 + tensorflow/core/graph/graph.cc | 1 + .../python/eager/def_function_xla_jit_test.py | 30 +++++++++++++++++++ tensorflow/python/util/tf_stack.cc | 2 +- 9 files changed, 77 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 3f3f961595d..45334c92179 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -352,7 +352,8 @@ LogicalResult Tf2XlaRewriter::PrepareParams() { // XlaCompiler within the context is only used by the functional ops to // compile functions. We are not handling those at the moment so XlaCompiler // is not required. 
- context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_); + context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_, + /*graph=*/nullptr); context_->Ref(); device_mgr_ = CreateDeviceMgr(device_type_); diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index 06423019f23..0de00581a2f 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -83,14 +83,31 @@ Allocator* XlaCompilationDevice::GetAllocator(AllocatorAttributes attr) { return allocator_.get(); } +// Attaches location from the node stack trace to metadata. As a heuristic, +// picks the last frame which does not contain the "tensorflow/python" substring +// (making exception for frames containing "test" to allow for testing the +// feature). +static void AttachLocationToMetadata(xla::OpMetadata& metadata, + OpKernel* op_kernel, XlaContext& context) { + if (const AbstractStackTrace* stack_trace = + context.StackTraceForNodeName(op_kernel->def().name())) { + if (absl::optional<StackFrame> frame = stack_trace->LastUserFrame()) { + metadata.set_source_file(frame->file_name); + metadata.set_source_line(frame->line_number); + } + } +} + void XlaCompilationDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { VLOG(4) << "XlaCompilationDevice::Compute " << FormatNodeDefForError(op_kernel->def()); - auto* b = XlaContext::Get(context).builder(); + XlaContext& xla_context = XlaContext::Get(context); + auto* b = xla_context.builder(); xla::OpMetadata metadata; metadata.set_op_type(op_kernel->type_string()); metadata.set_op_name(op_kernel->name()); + AttachLocationToMetadata(metadata, op_kernel, xla_context); b->SetOpMetadata(metadata); auto sharding_parse_result = ParseShardingFromDevice( diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 716146ec035..f55f9973853 100644 --- 
a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -1309,7 +1309,7 @@ Status XlaCompiler::CompileGraph( options_.device_type, name)); xla::XlaBuilder builder(name); - XlaContext* context = new XlaContext(this, &builder); + XlaContext* context = new XlaContext(this, &builder, graph.get()); core::ScopedUnref context_unref(context); std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end()); diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index cb5bf34208f..7e81644ae40 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -57,8 +57,15 @@ void XlaContext::set_args(std::vector<XlaExpression> args) { args_ = std::move(args); } -XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder) - : compiler_(compiler), builder_(builder) {} +XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, + const Graph* graph) + : compiler_(compiler), builder_(builder) { + if (graph) { + for (const Node* node : graph->nodes()) { + stack_traces_[node->name()] = node->GetStackTrace(); + } + } +} string XlaContext::DebugString() const { return "XLA JIT context"; } diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index e44ac05b702..8376471da89 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -44,13 +45,22 @@ class XlaContext : public ResourceBase { // Creates a new XlaContext. See the documentation on the class data fields // for descriptions of the arguments. 
- XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder); + XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, + const Graph* graph); // Virtual method defined by ResourceBase. string DebugString() const override; XlaCompiler* compiler() const { return compiler_; } + const AbstractStackTrace* StackTraceForNodeName(const std::string& name) { + const auto& it = stack_traces_.find(name); + if (it != stack_traces_.end()) { + return it->second.get(); + } + return nullptr; + } + // Returns the XlaBuilder that Ops use for compiling new expressions. xla::XlaBuilder* builder() { return builder_; } @@ -100,6 +110,9 @@ class XlaContext : public ResourceBase { // The XlaBuilder used to construct the subgraph's compiled representation. xla::XlaBuilder* builder_; + // Stack traces for the graph used for compilation. + StackTracesMap stack_traces_; + // Arguments to the Tensorflow graph, indexed by _Arg index. // Includes both compile-time constant arguments and runtime parameters. std::vector<XlaExpression> args_; diff --git a/tensorflow/core/common_runtime/inline_function_utils.cc b/tensorflow/core/common_runtime/inline_function_utils.cc index fc2b846401f..3bff26b6576 100644 --- a/tensorflow/core/common_runtime/inline_function_utils.cc +++ b/tensorflow/core/common_runtime/inline_function_utils.cc @@ -616,6 +616,7 @@ Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g, Node* clone = g->AddNode(ndef, &added_node); TF_CHECK_OK(added_node); node_map[n->id()] = clone; + clone->SetStackTrace(n->GetStackTrace()); // If there is an input control node, and one of: // a) the node has no data or control inputs, or diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 93f4eaf624e..2e7e7fbf4c3 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -454,6 +454,7 @@ Node* Graph::CopyNode(const Node* node) { copy->MaybeCopyOnWrite(); copy->props_->op_def = op_def; } + copy->SetStackTrace(node->GetStackTrace()); 
return copy; } diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py index f43efc90d47..5abb88bc399 100644 --- a/tensorflow/python/eager/def_function_xla_jit_test.py +++ b/tensorflow/python/eager/def_function_xla_jit_test.py @@ -170,6 +170,36 @@ class DefFunctionTest(xla_test.XLATestCase): 'not compilable'): xla_func(inputs) + @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not' + 'support stack traces') + def testPythonLocationInMetadata(self): + with ops.device('device:{}:0'.format(self.device)): + + @def_function.function(jit_compile=True) + def fn(x, y): + return x + y + + inputs = constant_op.constant([1, 2, 2, 3, 3]) + self.assertIn('def_function_xla_jit_test', + fn.experimental_get_compiler_ir(inputs, inputs)()) + + @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not' + 'support stack traces') + def testPythonLocationNestedInMetadata(self): + with ops.device('device:{}:0'.format(self.device)): + + @def_function.function(jit_compile=True) + def f(x, y): + return x + y + + @def_function.function(jit_compile=True) + def g(x, y): + return f(x, y) + + inputs = constant_op.constant([1, 2, 2, 3, 3]) + self.assertIn('def_function_xla_jit_test', + g.experimental_get_compiler_ir(inputs, inputs)()) + @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not' 'support stack traces') def testPythonStackTrace(self): diff --git a/tensorflow/python/util/tf_stack.cc b/tensorflow/python/util/tf_stack.cc index 454f07aa52f..ad43925fac8 100644 --- a/tensorflow/python/util/tf_stack.cc +++ b/tensorflow/python/util/tf_stack.cc @@ -177,11 +177,11 @@ class StackTraceWrapper : public AbstractStackTrace { void GenerateCache() const { // Grabbing the GIL solves two purposes: 1) makes the class thread-safe, and // 2) ToStackFrames and LineContents actually need it. 
- PyGILState_STATE state = PyGILState_Ensure(); if (stack_frames_cache_) { return; } + PyGILState_STATE state = PyGILState_Ensure(); absl::flat_hash_map<std::pair<std::string, int>, StackFrame> m; absl::flat_hash_set<std::string> f;