From e5d12438573d9c4dafb868012f6317099dfb674a Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Tue, 22 Dec 2020 15:57:19 -0800 Subject: [PATCH] [TF2XLA] Show stack traces of op definitions for some common tf2xla error messages Eventually it might make sense to always show this. PiperOrigin-RevId: 348707164 Change-Id: Id05656460920763c3e3260695333c41f69386e45 --- .../tf2xla/kernels/tensor_list_ops.cc | 24 ++++++++++++------- tensorflow/compiler/tf2xla/xla_op_kernel.cc | 17 ++++++++++++- tensorflow/compiler/tf2xla/xla_op_kernel.h | 4 ++++ 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 1ea0e797675..cdae14bc11f 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -154,7 +154,8 @@ class TensorListReserveOp : public XlaOpKernel { "XLA compilation requires a fixed tensor list size. Set the number " "of elements. This could also happen if you're using a TensorArray " "in a while loop that does not have its maximum_iteration set, you " - "can fix this by setting maximum_iteration to a suitable value.")); + "can fix this by setting maximum_iteration to a suitable value.", + ctx->StackTrace())); // If element shape is compile time constant and it's not "unknown rank" // shape (-1), create an initialized TensorList. Otherwise create an @@ -224,7 +225,8 @@ class EmptyTensorListOp : public XlaOpKernel { "the max number of elements. This could also happen if " "you're using a TensorArray in a while loop that does not " "have its maximum_iteration set, you can fix this by " - "setting maximum_iteration to a suitable value.")); + "setting maximum_iteration to a suitable value.", + ctx->StackTrace())); if (dtype_ != DT_VARIANT) { // We are creating a non-nested TensorList. @@ -292,7 +294,8 @@ class TensorListElementShapeOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, (IsTensorListInitialized(ctx->Input(0), &is_initialized))); OP_REQUIRES(ctx, is_initialized, - errors::InvalidArgument("TensorList is not initialized")); + errors::InvalidArgument("TensorList is not initialized", + ctx->StackTrace())); // Only non-nested TensorList is supported for now. bool is_nested; @@ -348,7 +351,8 @@ class TensorListGetItemOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, (IsTensorListInitialized(ctx->Input(0), &is_initialized))); OP_REQUIRES(ctx, is_initialized, - errors::InvalidArgument("TensorList is not initialized")); + errors::InvalidArgument("TensorList is not initialized", + ctx->StackTrace())); // Only non-nested TensorList is supported for now. bool is_nested; @@ -386,7 +390,8 @@ class TensorListGatherOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, (IsTensorListInitialized(ctx->Input(0), &is_initialized))); OP_REQUIRES(ctx, is_initialized, - errors::InvalidArgument("TensorList is not initialized")); + errors::InvalidArgument("TensorList is not initialized", + ctx->StackTrace())); // Only non-nested TensorList is supported for now. bool is_nested; @@ -437,7 +442,8 @@ class TensorListStackOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, (IsTensorListInitialized(ctx->Input(0), &is_initialized))); OP_REQUIRES(ctx, is_initialized, - errors::InvalidArgument("TensorList is not initialized")); + errors::InvalidArgument("TensorList is not initialized", + ctx->StackTrace())); // Only non-nested TensorList is supported for now. bool is_nested; @@ -468,7 +474,8 @@ class TensorListConcatOp : public XlaOpKernel { bool is_initialized; OP_REQUIRES_OK(ctx, (IsTensorListInitialized(input, &is_initialized))); OP_REQUIRES(ctx, is_initialized, - errors::InvalidArgument("TensorList is not initialized")); + errors::InvalidArgument("TensorList is not initialized", + ctx->StackTrace())); // Only non-nested TensorList is supported for now. bool is_nested; @@ -666,7 +673,8 @@ class TensorListPopBackOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, (IsTensorListInitialized(ctx->Input(0), &is_initialized))); OP_REQUIRES(ctx, is_initialized, - errors::InvalidArgument("TensorList is not initialized")); + errors::InvalidArgument("TensorList is not initialized", + ctx->StackTrace())); xla::XlaOp list = ctx->Input(0); xla::XlaOp list_result, element_result; diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 1d382fe5b9c..b1468c0d8f9 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -177,7 +177,8 @@ Status XlaOpKernelContext::ConstantInputReshaped( "This error means that a shape or dimension argument could not be " "evaluated at compile time, usually because the value of the argument " "depends on a parameter to the computation, on a variable, or on a " - "stateful operation such as a random number generator."); + "stateful operation such as a random number generator.", + StackTrace()); } Tensor temp(constant->dtype()); @@ -705,4 +706,18 @@ void XlaOpKernel::Compute(OpKernelContext* context) { Compile(&xla_context); } +std::string XlaOpKernelContext::StackTrace() const { + if (const AbstractStackTrace* stack_trace = + xla_context()->StackTraceForNodeName(op_kernel().name())) { + AbstractStackTrace::TracePrintingOptions opts; + opts.show_line_contents = true; + opts.filter_common_prefix = true; + opts.drop_internal_frames = true; + return absl::StrCat("\nStack trace for op definition: \n", + stack_trace->ToString(opts), "\n"); + } else { + return ""; + } +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 1ed343ba20f..04a2b83f76e 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -290,6 +290,10 @@ class XlaOpKernelContext { // separate specialization of the computation for each DataType. const xla::XlaComputation* GetOrCreateMul(const DataType type); + // Returns stack trace encoded as a string at a given module, or an empty + // string if none found. + std::string StackTrace() const; + private: // Returns the tensor of input `name`. const Tensor& GetInputTensorByName(absl::string_view name);