From 2c596172d2b92fabf3e0c378354160c99ac909cb Mon Sep 17 00:00:00 2001
From: George Karpenkov <cheshire@google.com>
Date: Mon, 8 Feb 2021 18:24:06 -0800
Subject: [PATCH] [TF2XLA] Attach Python stack trace to all INVALID_ARGUMENT
 Status objects created from the bridge

PiperOrigin-RevId: 356400325
Change-Id: I1b8fc2a771bcb6492ccde5fbeb14844154b38791
---
 .../tf2xla/kernels/tensor_list_ops.cc         | 24 +++++++------------
 tensorflow/compiler/tf2xla/xla_op_kernel.cc   | 21 +++++++++++-----
 2 files changed, 23 insertions(+), 22 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
index cdae14bc11f..1ea0e797675 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
@@ -154,8 +154,7 @@ class TensorListReserveOp : public XlaOpKernel {
             "XLA compilation requires a fixed tensor list size. Set the number "
             "of elements. This could also happen if you're using a TensorArray "
             "in a while loop that does not have its maximum_iteration set, you "
-            "can fix this by setting maximum_iteration to a suitable value.",
-            ctx->StackTrace()));
+            "can fix this by setting maximum_iteration to a suitable value."));
 
     // If element shape is compile time constant and it's not "unknown rank"
     // shape (-1), create an initialized TensorList. Otherwise create an
@@ -225,8 +224,7 @@ class EmptyTensorListOp : public XlaOpKernel {
                     "the max number of elements. This could also happen if "
                     "you're using a TensorArray in a while loop that does not "
                     "have its maximum_iteration set, you can fix this by "
-                    "setting maximum_iteration to a suitable value.",
-                    ctx->StackTrace()));
+                    "setting maximum_iteration to a suitable value."));
 
     if (dtype_ != DT_VARIANT) {
       // We are creating a non-nested TensorList.
@@ -294,8 +292,7 @@ class TensorListElementShapeOp : public XlaOpKernel {
     OP_REQUIRES_OK(ctx,
                    (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
     OP_REQUIRES(ctx, is_initialized,
-                errors::InvalidArgument("TensorList is not initialized",
-                                        ctx->StackTrace()));
+                errors::InvalidArgument("TensorList is not initialized"));
 
     // Only non-nested TensorList is supported for now.
     bool is_nested;
@@ -351,8 +348,7 @@ class TensorListGetItemOp : public XlaOpKernel {
     OP_REQUIRES_OK(ctx,
                    (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
     OP_REQUIRES(ctx, is_initialized,
-                errors::InvalidArgument("TensorList is not initialized",
-                                        ctx->StackTrace()));
+                errors::InvalidArgument("TensorList is not initialized"));
 
     // Only non-nested TensorList is supported for now.
     bool is_nested;
@@ -390,8 +386,7 @@ class TensorListGatherOp : public XlaOpKernel {
     OP_REQUIRES_OK(ctx,
                    (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
     OP_REQUIRES(ctx, is_initialized,
-                errors::InvalidArgument("TensorList is not initialized",
-                                        ctx->StackTrace()));
+                errors::InvalidArgument("TensorList is not initialized"));
 
     // Only non-nested TensorList is supported for now.
     bool is_nested;
@@ -442,8 +437,7 @@ class TensorListStackOp : public XlaOpKernel {
     OP_REQUIRES_OK(ctx,
                    (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
     OP_REQUIRES(ctx, is_initialized,
-                errors::InvalidArgument("TensorList is not initialized",
-                                        ctx->StackTrace()));
+                errors::InvalidArgument("TensorList is not initialized"));
 
     // Only non-nested TensorList is supported for now.
     bool is_nested;
@@ -474,8 +468,7 @@ class TensorListConcatOp : public XlaOpKernel {
     bool is_initialized;
     OP_REQUIRES_OK(ctx, (IsTensorListInitialized(input, &is_initialized)));
     OP_REQUIRES(ctx, is_initialized,
-                errors::InvalidArgument("TensorList is not initialized",
-                                        ctx->StackTrace()));
+                errors::InvalidArgument("TensorList is not initialized"));
 
     // Only non-nested TensorList is supported for now.
     bool is_nested;
@@ -673,8 +666,7 @@ class TensorListPopBackOp : public XlaOpKernel {
     OP_REQUIRES_OK(ctx,
                    (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
     OP_REQUIRES(ctx, is_initialized,
-                errors::InvalidArgument("TensorList is not initialized",
-                                        ctx->StackTrace()));
+                errors::InvalidArgument("TensorList is not initialized"));
 
     xla::XlaOp list = ctx->Input(0);
     xla::XlaOp list_result, element_result;
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index bc80a6bf12b..0a298b798be 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -177,8 +177,7 @@ Status XlaOpKernelContext::ConstantInputReshaped(
         "This error means that a shape or dimension argument could not be "
         "evaluated at compile time, usually because the value of the argument "
         "depends on a parameter to the computation, on a variable, or on a "
-        "stateful operation such as a random number generator.",
-        StackTrace());
+        "stateful operation such as a random number generator.");
   }
 
   Tensor temp(constant->dtype());
@@ -664,19 +663,29 @@ Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type,
                               builder());
 }
 
+static Status GetStatusWithStackTrace(const Status& s,
+                                      const XlaOpKernelContext* ctx) {
+  if (s.code() == error::INVALID_ARGUMENT) {
+    return Status{s.code(),
+                  absl::StrCat(s.error_message(), "\n", ctx->StackTrace())};
+  }
+  return s;
+}
+
 void XlaOpKernelContext::CtxFailure(const Status& s) {
-  context_->CtxFailure(s);
+  context_->CtxFailure(GetStatusWithStackTrace(s, this));
 }
 void XlaOpKernelContext::CtxFailureWithWarning(const Status& s) {
-  context_->CtxFailureWithWarning(s);
+  context_->CtxFailureWithWarning(GetStatusWithStackTrace(s, this));
 }
+
 void XlaOpKernelContext::CtxFailure(const char* file, int line,
                                     const Status& s) {
-  context_->CtxFailure(file, line, s);
+  context_->CtxFailure(file, line, GetStatusWithStackTrace(s, this));
 }
 void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line,
                                                const Status& s) {
-  context_->CtxFailureWithWarning(file, line, s);
+  context_->CtxFailureWithWarning(file, line, GetStatusWithStackTrace(s, this));
 }
 
 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMax(