[TF2XLA] Display Python location information in XLA metadata
As a heuristic, we show the last stack frame that does not contain the "tensorflow/python" substring.

PiperOrigin-RevId: 346892818
Change-Id: I72b5953ba5de9c0a10648805a852fcc033dd960b
This commit is contained in:
parent
2948461bab
commit
29ebd644a2
@ -352,7 +352,8 @@ LogicalResult Tf2XlaRewriter::PrepareParams() {
|
|||||||
// XlaCompiler within the context is only used by the functional ops to
|
// XlaCompiler within the context is only used by the functional ops to
|
||||||
// compile functions. We are not handling those at the moment so XlaCompiler
|
// compile functions. We are not handling those at the moment so XlaCompiler
|
||||||
// is not required.
|
// is not required.
|
||||||
context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_);
|
context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_,
|
||||||
|
/*graph=*/nullptr);
|
||||||
context_->Ref();
|
context_->Ref();
|
||||||
|
|
||||||
device_mgr_ = CreateDeviceMgr(device_type_);
|
device_mgr_ = CreateDeviceMgr(device_type_);
|
||||||
|
@ -83,14 +83,31 @@ Allocator* XlaCompilationDevice::GetAllocator(AllocatorAttributes attr) {
|
|||||||
return allocator_.get();
|
return allocator_.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Attaches location from the node stack trace to metadata. As a heuristic,
|
||||||
|
// picks the last frame which does not contain the "tensorflow/python" substring
|
||||||
|
// (making an exception for frames containing "test" to allow for testing the
|
||||||
|
// feature).
|
||||||
|
static void AttachLocationToMetadata(xla::OpMetadata& metadata,
|
||||||
|
OpKernel* op_kernel, XlaContext& context) {
|
||||||
|
if (const AbstractStackTrace* stack_trace =
|
||||||
|
context.StackTraceForNodeName(op_kernel->def().name())) {
|
||||||
|
if (absl::optional<StackFrame> frame = stack_trace->LastUserFrame()) {
|
||||||
|
metadata.set_source_file(frame->file_name);
|
||||||
|
metadata.set_source_line(frame->line_number);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void XlaCompilationDevice::Compute(OpKernel* op_kernel,
|
void XlaCompilationDevice::Compute(OpKernel* op_kernel,
|
||||||
OpKernelContext* context) {
|
OpKernelContext* context) {
|
||||||
VLOG(4) << "XlaCompilationDevice::Compute "
|
VLOG(4) << "XlaCompilationDevice::Compute "
|
||||||
<< FormatNodeDefForError(op_kernel->def());
|
<< FormatNodeDefForError(op_kernel->def());
|
||||||
auto* b = XlaContext::Get(context).builder();
|
XlaContext& xla_context = XlaContext::Get(context);
|
||||||
|
auto* b = xla_context.builder();
|
||||||
xla::OpMetadata metadata;
|
xla::OpMetadata metadata;
|
||||||
metadata.set_op_type(op_kernel->type_string());
|
metadata.set_op_type(op_kernel->type_string());
|
||||||
metadata.set_op_name(op_kernel->name());
|
metadata.set_op_name(op_kernel->name());
|
||||||
|
AttachLocationToMetadata(metadata, op_kernel, xla_context);
|
||||||
b->SetOpMetadata(metadata);
|
b->SetOpMetadata(metadata);
|
||||||
|
|
||||||
auto sharding_parse_result = ParseShardingFromDevice(
|
auto sharding_parse_result = ParseShardingFromDevice(
|
||||||
|
@ -1309,7 +1309,7 @@ Status XlaCompiler::CompileGraph(
|
|||||||
options_.device_type, name));
|
options_.device_type, name));
|
||||||
|
|
||||||
xla::XlaBuilder builder(name);
|
xla::XlaBuilder builder(name);
|
||||||
XlaContext* context = new XlaContext(this, &builder);
|
XlaContext* context = new XlaContext(this, &builder, graph.get());
|
||||||
core::ScopedUnref context_unref(context);
|
core::ScopedUnref context_unref(context);
|
||||||
|
|
||||||
std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end());
|
std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end());
|
||||||
|
@ -57,8 +57,15 @@ void XlaContext::set_args(std::vector<XlaExpression> args) {
|
|||||||
args_ = std::move(args);
|
args_ = std::move(args);
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder)
|
XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
|
||||||
: compiler_(compiler), builder_(builder) {}
|
const Graph* graph)
|
||||||
|
: compiler_(compiler), builder_(builder) {
|
||||||
|
if (graph) {
|
||||||
|
for (const Node* node : graph->nodes()) {
|
||||||
|
stack_traces_[node->name()] = node->GetStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
string XlaContext::DebugString() const { return "XLA JIT context"; }
|
string XlaContext::DebugString() const { return "XLA JIT context"; }
|
||||||
|
|
||||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/resource_mgr.h"
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -44,13 +45,22 @@ class XlaContext : public ResourceBase {
|
|||||||
|
|
||||||
// Creates a new XlaContext. See the documentation on the class data fields
|
// Creates a new XlaContext. See the documentation on the class data fields
|
||||||
// for descriptions of the arguments.
|
// for descriptions of the arguments.
|
||||||
XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder);
|
XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
|
||||||
|
const Graph* graph);
|
||||||
|
|
||||||
// Virtual method defined by ResourceBase.
|
// Virtual method defined by ResourceBase.
|
||||||
string DebugString() const override;
|
string DebugString() const override;
|
||||||
|
|
||||||
XlaCompiler* compiler() const { return compiler_; }
|
XlaCompiler* compiler() const { return compiler_; }
|
||||||
|
|
||||||
|
const AbstractStackTrace* StackTraceForNodeName(const std::string& name) {
|
||||||
|
const auto& it = stack_traces_.find(name);
|
||||||
|
if (it != stack_traces_.end()) {
|
||||||
|
return it->second.get();
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
// Returns the XlaBuilder that Ops use for compiling new expressions.
|
// Returns the XlaBuilder that Ops use for compiling new expressions.
|
||||||
xla::XlaBuilder* builder() { return builder_; }
|
xla::XlaBuilder* builder() { return builder_; }
|
||||||
|
|
||||||
@ -100,6 +110,9 @@ class XlaContext : public ResourceBase {
|
|||||||
// The XlaBuilder used to construct the subgraph's compiled representation.
|
// The XlaBuilder used to construct the subgraph's compiled representation.
|
||||||
xla::XlaBuilder* builder_;
|
xla::XlaBuilder* builder_;
|
||||||
|
|
||||||
|
// Stack traces for the graph used for compilation.
|
||||||
|
StackTracesMap stack_traces_;
|
||||||
|
|
||||||
// Arguments to the Tensorflow graph, indexed by _Arg index.
|
// Arguments to the Tensorflow graph, indexed by _Arg index.
|
||||||
// Includes both compile-time constant arguments and runtime parameters.
|
// Includes both compile-time constant arguments and runtime parameters.
|
||||||
std::vector<XlaExpression> args_;
|
std::vector<XlaExpression> args_;
|
||||||
|
@ -616,6 +616,7 @@ Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
|
|||||||
Node* clone = g->AddNode(ndef, &added_node);
|
Node* clone = g->AddNode(ndef, &added_node);
|
||||||
TF_CHECK_OK(added_node);
|
TF_CHECK_OK(added_node);
|
||||||
node_map[n->id()] = clone;
|
node_map[n->id()] = clone;
|
||||||
|
clone->SetStackTrace(n->GetStackTrace());
|
||||||
|
|
||||||
// If there is an input control node, and one of:
|
// If there is an input control node, and one of:
|
||||||
// a) the node has no data or control inputs, or
|
// a) the node has no data or control inputs, or
|
||||||
|
@ -454,6 +454,7 @@ Node* Graph::CopyNode(const Node* node) {
|
|||||||
copy->MaybeCopyOnWrite();
|
copy->MaybeCopyOnWrite();
|
||||||
copy->props_->op_def = op_def;
|
copy->props_->op_def = op_def;
|
||||||
}
|
}
|
||||||
|
copy->SetStackTrace(node->GetStackTrace());
|
||||||
|
|
||||||
return copy;
|
return copy;
|
||||||
}
|
}
|
||||||
|
@ -170,6 +170,36 @@ class DefFunctionTest(xla_test.XLATestCase):
|
|||||||
'not compilable'):
|
'not compilable'):
|
||||||
xla_func(inputs)
|
xla_func(inputs)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
|
||||||
|
'support stack traces')
|
||||||
|
def testPythonLocationInMetadata(self):
|
||||||
|
with ops.device('device:{}:0'.format(self.device)):
|
||||||
|
|
||||||
|
@def_function.function(jit_compile=True)
|
||||||
|
def fn(x, y):
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
inputs = constant_op.constant([1, 2, 2, 3, 3])
|
||||||
|
self.assertIn('def_function_xla_jit_test',
|
||||||
|
fn.experimental_get_compiler_ir(inputs, inputs)())
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
|
||||||
|
'support stack traces')
|
||||||
|
def testPythonLocationNestedInMetadata(self):
|
||||||
|
with ops.device('device:{}:0'.format(self.device)):
|
||||||
|
|
||||||
|
@def_function.function(jit_compile=True)
|
||||||
|
def f(x, y):
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
@def_function.function(jit_compile=True)
|
||||||
|
def g(x, y):
|
||||||
|
return f(x, y)
|
||||||
|
|
||||||
|
inputs = constant_op.constant([1, 2, 2, 3, 3])
|
||||||
|
self.assertIn('def_function_xla_jit_test',
|
||||||
|
g.experimental_get_compiler_ir(inputs, inputs)())
|
||||||
|
|
||||||
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
|
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
|
||||||
'support stack traces')
|
'support stack traces')
|
||||||
def testPythonStackTrace(self):
|
def testPythonStackTrace(self):
|
||||||
|
@ -177,11 +177,11 @@ class StackTraceWrapper : public AbstractStackTrace {
|
|||||||
void GenerateCache() const {
|
void GenerateCache() const {
|
||||||
// Grabbing the GIL solves two purposes: 1) makes the class thread-safe, and
|
// Grabbing the GIL solves two purposes: 1) makes the class thread-safe, and
|
||||||
// 2) ToStackFrames and LineContents actually need it.
|
// 2) ToStackFrames and LineContents actually need it.
|
||||||
PyGILState_STATE state = PyGILState_Ensure();
|
|
||||||
if (stack_frames_cache_) {
|
if (stack_frames_cache_) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PyGILState_STATE state = PyGILState_Ensure();
|
||||||
absl::flat_hash_map<std::pair<std::string, int>, StackFrame> m;
|
absl::flat_hash_map<std::pair<std::string, int>, StackFrame> m;
|
||||||
absl::flat_hash_set<std::string> f;
|
absl::flat_hash_set<std::string> f;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user