[TF2XLA] Display Python location information in XLA metadata

As a heuristic, we show the last frame whose filename does not contain the
"tensorflow/python" substring.

PiperOrigin-RevId: 346892818
Change-Id: I72b5953ba5de9c0a10648805a852fcc033dd960b
This commit is contained in:
George Karpenkov 2020-12-10 16:44:59 -08:00 committed by TensorFlower Gardener
parent 2948461bab
commit 29ebd644a2
9 changed files with 77 additions and 7 deletions

View File

@ -352,7 +352,8 @@ LogicalResult Tf2XlaRewriter::PrepareParams() {
// XlaCompiler within the context is only used by the functional ops to
// compile functions. We are not handling those at the moment so XlaCompiler
// is not required.
context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_);
context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_,
/*graph=*/nullptr);
context_->Ref();
device_mgr_ = CreateDeviceMgr(device_type_);

View File

@ -83,14 +83,31 @@ Allocator* XlaCompilationDevice::GetAllocator(AllocatorAttributes attr) {
return allocator_.get();
}
// Attaches location from the node stack trace to metadata. As a heuristic,
// picks the last frame which does not contain the "tensorflow/python"
// substring (making exception for frames containing "test" to allow for
// testing the feature).
static void AttachLocationToMetadata(xla::OpMetadata& metadata,
                                     OpKernel* op_kernel, XlaContext& context) {
  const AbstractStackTrace* trace =
      context.StackTraceForNodeName(op_kernel->def().name());
  if (trace == nullptr) {
    return;
  }
  // LastUserFrame() applies the "last non-tensorflow/python frame" heuristic.
  if (absl::optional<StackFrame> user_frame = trace->LastUserFrame()) {
    metadata.set_source_file(user_frame->file_name);
    metadata.set_source_line(user_frame->line_number);
  }
}
void XlaCompilationDevice::Compute(OpKernel* op_kernel,
OpKernelContext* context) {
VLOG(4) << "XlaCompilationDevice::Compute "
<< FormatNodeDefForError(op_kernel->def());
auto* b = XlaContext::Get(context).builder();
XlaContext& xla_context = XlaContext::Get(context);
auto* b = xla_context.builder();
xla::OpMetadata metadata;
metadata.set_op_type(op_kernel->type_string());
metadata.set_op_name(op_kernel->name());
AttachLocationToMetadata(metadata, op_kernel, xla_context);
b->SetOpMetadata(metadata);
auto sharding_parse_result = ParseShardingFromDevice(

View File

@ -1309,7 +1309,7 @@ Status XlaCompiler::CompileGraph(
options_.device_type, name));
xla::XlaBuilder builder(name);
XlaContext* context = new XlaContext(this, &builder);
XlaContext* context = new XlaContext(this, &builder, graph.get());
core::ScopedUnref context_unref(context);
std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end());

View File

@ -57,8 +57,15 @@ void XlaContext::set_args(std::vector<XlaExpression> args) {
args_ = std::move(args);
}
// Constructs a context with no associated Graph: stack_traces_ stays empty,
// so StackTraceForNodeName() will return nullptr for every node name.
XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder)
    : compiler_(compiler), builder_(builder) {}
// Constructs a context and, when a Graph is supplied, copies each node's
// stack trace into stack_traces_ keyed by node name so it can later be
// looked up via StackTraceForNodeName().
XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
                       const Graph* graph)
    : compiler_(compiler), builder_(builder) {
  if (graph == nullptr) {
    return;
  }
  for (const Node* n : graph->nodes()) {
    stack_traces_[n->name()] = n->GetStackTrace();
  }
}
// Overrides ResourceBase::DebugString(); identifies this resource in logs.
string XlaContext::DebugString() const {
  return "XLA JIT context";
}

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
@ -44,13 +45,22 @@ class XlaContext : public ResourceBase {
// Creates a new XlaContext. See the documentation on the class data fields
// for descriptions of the arguments.
XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder);
XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
const Graph* graph);
// Virtual method defined by ResourceBase.
string DebugString() const override;
XlaCompiler* compiler() const { return compiler_; }
// Returns the stack trace recorded for the node with the given name, or
// nullptr if none was captured. The returned trace is owned by the
// stack_traces_ map (the stored entry's .get() is returned) and stays valid
// for the lifetime of this context.
const AbstractStackTrace* StackTraceForNodeName(const std::string& name) {
  // Take the iterator by value: binding `const auto&` to the temporary
  // returned by find() is a needless reference to a prvalue.
  auto it = stack_traces_.find(name);
  if (it != stack_traces_.end()) {
    return it->second.get();
  }
  return nullptr;
}
// Returns the XlaBuilder that Ops use for compiling new expressions.
xla::XlaBuilder* builder() { return builder_; }
@ -100,6 +110,9 @@ class XlaContext : public ResourceBase {
// The XlaBuilder used to construct the subgraph's compiled representation.
xla::XlaBuilder* builder_;
// Stack traces for the graph used for compilation.
StackTracesMap stack_traces_;
// Arguments to the Tensorflow graph, indexed by _Arg index.
// Includes both compile-time constant arguments and runtime parameters.
std::vector<XlaExpression> args_;

View File

@ -616,6 +616,7 @@ Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
Node* clone = g->AddNode(ndef, &added_node);
TF_CHECK_OK(added_node);
node_map[n->id()] = clone;
clone->SetStackTrace(n->GetStackTrace());
// If there is an input control node, and one of:
// a) the node has no data or control inputs, or

View File

@ -454,6 +454,7 @@ Node* Graph::CopyNode(const Node* node) {
copy->MaybeCopyOnWrite();
copy->props_->op_def = op_def;
}
copy->SetStackTrace(node->GetStackTrace());
return copy;
}

View File

@ -170,6 +170,36 @@ class DefFunctionTest(xla_test.XLATestCase):
'not compilable'):
xla_func(inputs)
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not '
                               'support stack traces')
def testPythonLocationInMetadata(self):
  """Compiled HLO metadata should name this test file as the source location."""
  with ops.device('device:{}:0'.format(self.device)):

    @def_function.function(jit_compile=True)
    def fn(x, y):
      return x + y

    inputs = constant_op.constant([1, 2, 2, 3, 3])
    # The last user frame is this file, so its basename must show up in the
    # generated compiler IR.
    self.assertIn('def_function_xla_jit_test',
                  fn.experimental_get_compiler_ir(inputs, inputs)())
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not '
                               'support stack traces')
def testPythonLocationNestedInMetadata(self):
  """Source-location metadata should survive a nested tf.function call."""
  with ops.device('device:{}:0'.format(self.device)):

    @def_function.function(jit_compile=True)
    def f(x, y):
      return x + y

    # g calls f; the inlined body must still carry this file's location.
    @def_function.function(jit_compile=True)
    def g(x, y):
      return f(x, y)

    inputs = constant_op.constant([1, 2, 2, 3, 3])
    self.assertIn('def_function_xla_jit_test',
                  g.experimental_get_compiler_ir(inputs, inputs)())
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
'support stack traces')
def testPythonStackTrace(self):

View File

@ -177,11 +177,11 @@ class StackTraceWrapper : public AbstractStackTrace {
void GenerateCache() const {
// Grabbing the GIL solves two purposes: 1) makes the class thread-safe, and
// 2) ToStackFrames and LineContents actually need it.
PyGILState_STATE state = PyGILState_Ensure();
if (stack_frames_cache_) {
return;
}
PyGILState_STATE state = PyGILState_Ensure();
absl::flat_hash_map<std::pair<std::string, int>, StackFrame> m;
absl::flat_hash_set<std::string> f;