Remove _FallbackError from the stack trace

This can be confusing since it includes the full traceback of the fastpath
eager call.

The stack trace still includes <op>_fallback_error call to the function.

PiperOrigin-RevId: 304515138
Change-Id: I91f068fe3fe2596089da8d4cd0118782bc20fdd5
This commit is contained in:
Akshay Modi 2020-04-02 17:56:50 -07:00 committed by TensorFlower Gardener
parent bfcaa9b762
commit 1ca0d81b0a
2 changed files with 30 additions and 12 deletions

View File

@ -18,6 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import traceback
import numpy as np
from tensorflow.python import pywrap_tfe
@ -28,6 +31,7 @@ from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
@ -332,6 +336,18 @@ class Tests(test.TestCase):
# TODO(b/147828820): Converting with tensors should work.
# _ = ops.EagerTensor([[t]], device=ctx.device_name, dtype=None)
def testFallbackErrorNotVisibleWhenFallbackMethodRaises(self):
ctx = context.context()
ctx.ensure_initialized()
try:
math_ops.mat_mul([[1., 1.] * 2], [[1., 1.] * 3])
except errors.InvalidArgumentError:
etype, value, tb = sys.exc_info()
full_exception_text = " ".join(
traceback.format_exception(etype, value, tb))
self.assertNotRegex(full_exception_text, "_FallbackException")
if __name__ == "__main__":
test.main()

View File

@ -806,18 +806,6 @@ void GenEagerPythonOp::AddEagerFastPathExecute() {
// Handle fallback.
if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
strings::StrAppend(&fallback_params, "ctx=_ctx");
strings::StrAppend(&result_, " ", "except _core._FallbackException:\n");
strings::StrAppend(&result_, " try:\n");
strings::StrAppend(
&result_, " ", "return ", function_name_, kEagerFallbackSuffix,
"(\n",
WordWrap(strings::StrCat(" "),
strings::StrCat(fallback_params, ")"), kRightMargin),
"\n");
strings::StrAppend(&result_, " except _core._SymbolicException:\n");
strings::StrAppend(&result_,
" pass # Add nodes to the TensorFlow graph.\n");
AddDispatch(" ");
// Any errors thrown from execute need to be unwrapped from
// _NotOkStatusException.
@ -825,6 +813,20 @@ void GenEagerPythonOp::AddEagerFastPathExecute() {
"except _core._NotOkStatusException as e:\n");
strings::StrAppend(&result_, " ",
"_ops.raise_from_not_ok_status(e, name)\n");
strings::StrAppend(&result_, " ", "except _core._FallbackException:\n");
strings::StrAppend(&result_, " pass\n");
strings::StrAppend(&result_, " try:\n");
strings::StrAppend(
&result_, " ", "return ", function_name_, kEagerFallbackSuffix,
"(\n",
WordWrap(strings::StrCat(" "),
strings::StrCat(fallback_params, ")"), kRightMargin),
"\n");
strings::StrAppend(&result_, " except _core._SymbolicException:\n");
strings::StrAppend(&result_,
" pass # Add nodes to the TensorFlow graph.\n");
AddDispatch(" ");
}
void GenEagerPythonOp::AddEagerInferredAttrs(const string& indentation) {