[TF2XLA] Display Python location information in XLA metadata
As a heuristic, we show the last frame which does not contain the "tensorflow/python" substring.

PiperOrigin-RevId: 346892818
Change-Id: I72b5953ba5de9c0a10648805a852fcc033dd960b
commit 29ebd644a2 (parent 2948461bab)
Changed files under: tensorflow/compiler/mlir/xla/transforms, tensorflow/compiler/tf2xla, tensorflow/core, tensorflow/python
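The selection rule is simple enough to sketch. Below is a minimal Python illustration of the heuristic, assuming frames arrive as (file_name, line_number) pairs ordered outermost-first; last_user_frame is a hypothetical helper for illustration only, not the C++ implementation this commit touches:

    # Hypothetical helper illustrating the heuristic only; the real logic lives
    # in the C++ stack-trace code changed below.
    def last_user_frame(frames):
      """frames: (file_name, line_number) pairs, ordered outermost-first."""
      for file_name, line_number in reversed(frames):
        # Skip TensorFlow-internal frames, but keep files containing "test" so
        # the feature can be exercised from unit tests.
        if 'tensorflow/python' not in file_name or 'test' in file_name:
          return file_name, line_number
      return None

    # last_user_frame([('my_model.py', 42),
    #                  ('tensorflow/python/ops/math_ops.py', 1500)])
    # returns ('my_model.py', 42)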
@@ -352,7 +352,8 @@ LogicalResult Tf2XlaRewriter::PrepareParams() {
   // XlaCompiler within the context is only used by the functional ops to
   // compile functions. We are not handling those at the moment so XlaCompiler
   // is not required.
-  context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_);
+  context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_,
+                                        /*graph=*/nullptr);
   context_->Ref();
 
   device_mgr_ = CreateDeviceMgr(device_type_);

@@ -83,14 +83,31 @@ Allocator* XlaCompilationDevice::GetAllocator(AllocatorAttributes attr) {
   return allocator_.get();
 }
 
+// Attaches location from the node stack trace to metadata. As a heuristic,
+// picks the last frame which does not contain the "tensorflow/python" substring
+// (making exception for frames containing "test" to allow for testing the
+// feature).
+static void AttachLocationToMetadata(xla::OpMetadata& metadata,
+                                     OpKernel* op_kernel, XlaContext& context) {
+  if (const AbstractStackTrace* stack_trace =
+          context.StackTraceForNodeName(op_kernel->def().name())) {
+    if (absl::optional<StackFrame> frame = stack_trace->LastUserFrame()) {
+      metadata.set_source_file(frame->file_name);
+      metadata.set_source_line(frame->line_number);
+    }
+  }
+}
+
 void XlaCompilationDevice::Compute(OpKernel* op_kernel,
                                    OpKernelContext* context) {
   VLOG(4) << "XlaCompilationDevice::Compute "
           << FormatNodeDefForError(op_kernel->def());
-  auto* b = XlaContext::Get(context).builder();
+  XlaContext& xla_context = XlaContext::Get(context);
+  auto* b = xla_context.builder();
   xla::OpMetadata metadata;
   metadata.set_op_type(op_kernel->type_string());
   metadata.set_op_name(op_kernel->name());
+  AttachLocationToMetadata(metadata, op_kernel, xla_context);
   b->SetOpMetadata(metadata);
 
   auto sharding_parse_result = ParseShardingFromDevice(

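A note on the mechanism in the hunk above: the builder's SetOpMetadata applies to the HLO instructions emitted after the call, until the metadata is changed or cleared, so attaching one source location per TensorFlow kernel should be enough to tag every XLA op that kernel lowers to.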
@@ -1309,7 +1309,7 @@ Status XlaCompiler::CompileGraph(
                                    options_.device_type, name));
 
   xla::XlaBuilder builder(name);
-  XlaContext* context = new XlaContext(this, &builder);
+  XlaContext* context = new XlaContext(this, &builder, graph.get());
   core::ScopedUnref context_unref(context);
 
   std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end());

@@ -57,8 +57,15 @@ void XlaContext::set_args(std::vector<XlaExpression> args) {
   args_ = std::move(args);
 }
 
-XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder)
-    : compiler_(compiler), builder_(builder) {}
+XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
+                       const Graph* graph)
+    : compiler_(compiler), builder_(builder) {
+  if (graph) {
+    for (const Node* node : graph->nodes()) {
+      stack_traces_[node->name()] = node->GetStackTrace();
+    }
+  }
+}
 
 string XlaContext::DebugString() const { return "XLA JIT context"; }
 

@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/platform/macros.h"
 
 namespace tensorflow {

@@ -44,13 +45,22 @@ class XlaContext : public ResourceBase {
 
   // Creates a new XlaContext. See the documentation on the class data fields
   // for descriptions of the arguments.
-  XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder);
+  XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
+             const Graph* graph);
 
   // Virtual method defined by ResourceBase.
   string DebugString() const override;
 
   XlaCompiler* compiler() const { return compiler_; }
 
+  const AbstractStackTrace* StackTraceForNodeName(const std::string& name) {
+    const auto& it = stack_traces_.find(name);
+    if (it != stack_traces_.end()) {
+      return it->second.get();
+    }
+    return nullptr;
+  }
+
   // Returns the XlaBuilder that Ops use for compiling new expressions.
   xla::XlaBuilder* builder() { return builder_; }
 

@@ -100,6 +110,9 @@ class XlaContext : public ResourceBase {
   // The XlaBuilder used to construct the subgraph's compiled representation.
   xla::XlaBuilder* builder_;
 
+  // Stack traces for the graph used for compilation.
+  StackTracesMap stack_traces_;
+
   // Arguments to the Tensorflow graph, indexed by _Arg index.
   // Includes both compile-time constant arguments and runtime parameters.
   std::vector<XlaExpression> args_;

@@ -616,6 +616,7 @@ Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
     Node* clone = g->AddNode(ndef, &added_node);
     TF_CHECK_OK(added_node);
     node_map[n->id()] = clone;
+    clone->SetStackTrace(n->GetStackTrace());
 
     // If there is an input control node, and one of:
     // a) the node has no data or control inputs, or

@@ -454,6 +454,7 @@ Node* Graph::CopyNode(const Node* node) {
     copy->MaybeCopyOnWrite();
     copy->props_->op_def = op_def;
   }
+  copy->SetStackTrace(node->GetStackTrace());
 
   return copy;
 }

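The two hunks above propagate the recorded Python stack trace when nodes are cloned, during function inlining and during Graph::CopyNode, so the location information survives these graph transformations and is still attached to the nodes by the time XlaContext captures it.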
@@ -170,6 +170,36 @@ class DefFunctionTest(xla_test.XLATestCase):
                                  'not compilable'):
         xla_func(inputs)
 
+  @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
+                                 'support stack traces')
+  def testPythonLocationInMetadata(self):
+    with ops.device('device:{}:0'.format(self.device)):
+
+      @def_function.function(jit_compile=True)
+      def fn(x, y):
+        return x + y
+
+      inputs = constant_op.constant([1, 2, 2, 3, 3])
+      self.assertIn('def_function_xla_jit_test',
+                    fn.experimental_get_compiler_ir(inputs, inputs)())
+
+  @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
+                                 'support stack traces')
+  def testPythonLocationNestedInMetadata(self):
+    with ops.device('device:{}:0'.format(self.device)):
+
+      @def_function.function(jit_compile=True)
+      def f(x, y):
+        return x + y
+
+      @def_function.function(jit_compile=True)
+      def g(x, y):
+        return f(x, y)
+
+      inputs = constant_op.constant([1, 2, 2, 3, 3])
+      self.assertIn('def_function_xla_jit_test',
+                    g.experimental_get_compiler_ir(inputs, inputs)())
+
   @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
                                  'support stack traces')
   def testPythonStackTrace(self):

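For completeness, a user-level sketch of what these tests exercise; tf.function(jit_compile=True) and experimental_get_compiler_ir mirror the def_function.function calls above, but the exact text of the returned HLO, and therefore the check at the end, varies across TensorFlow/XLA versions:

    import tensorflow as tf

    @tf.function(jit_compile=True)
    def add(x, y):
      return x + y  # after this change, this line's file/line land in the op metadata

    x = tf.constant([1, 2, 3])
    hlo_text = add.experimental_get_compiler_ir(x, x)()  # returns HLO text by default
    print(__file__ in hlo_text)  # illustrative check; exact IR text is version-dependent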
@@ -177,11 +177,11 @@ class StackTraceWrapper : public AbstractStackTrace {
   void GenerateCache() const {
     // Grabbing the GIL solves two purposes: 1) makes the class thread-safe, and
     // 2) ToStackFrames and LineContents actually need it.
+    PyGILState_STATE state = PyGILState_Ensure();
     if (stack_frames_cache_) {
       return;
     }
 
-    PyGILState_STATE state = PyGILState_Ensure();
     absl::flat_hash_map<std::pair<std::string, int>, StackFrame> m;
     absl::flat_hash_set<std::string> f;
 
