Py_BuildValue might fail if the const char* cannot be interpreted as utf-8
PiperOrigin-RevId: 218198785
This commit is contained in:
parent
3d715da989
commit
4a7b57e93d
@ -250,6 +250,7 @@ bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status,
|
|||||||
tensorflow::Safe_PyObjectPtr py_type_enum(
|
tensorflow::Safe_PyObjectPtr py_type_enum(
|
||||||
PyObject_GetAttrString(py_value, "_type_enum"));
|
PyObject_GetAttrString(py_value, "_type_enum"));
|
||||||
if (py_type_enum == nullptr) {
|
if (py_type_enum == nullptr) {
|
||||||
|
PyErr_Clear();
|
||||||
TF_SetStatus(
|
TF_SetStatus(
|
||||||
status, TF_INVALID_ARGUMENT,
|
status, TF_INVALID_ARGUMENT,
|
||||||
tensorflow::strings::StrCat("Expecting a DType.dtype for attr ", key,
|
tensorflow::strings::StrCat("Expecting a DType.dtype for attr ", key,
|
||||||
@ -795,6 +796,13 @@ int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
|
|||||||
if (exception_class != nullptr) {
|
if (exception_class != nullptr) {
|
||||||
tensorflow::Safe_PyObjectPtr val(
|
tensorflow::Safe_PyObjectPtr val(
|
||||||
Py_BuildValue("si", msg, TF_GetCode(status)));
|
Py_BuildValue("si", msg, TF_GetCode(status)));
|
||||||
|
if (PyErr_Occurred()) {
|
||||||
|
// NOTE: This hides the actual error (i.e. the reason `status` was not
|
||||||
|
// TF_OK), but there is nothing we can do at this point since we can't
|
||||||
|
// generate a reasonable error from the status.
|
||||||
|
// Consider adding a message explaining this.
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
PyErr_SetObject(exception_class, val.get());
|
PyErr_SetObject(exception_class, val.get());
|
||||||
return -1;
|
return -1;
|
||||||
} else {
|
} else {
|
||||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import re
|
import re
|
||||||
|
import sys
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -32,6 +33,7 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import io_ops
|
||||||
|
|
||||||
|
|
||||||
def _create_tensor(value, device=None, dtype=None):
|
def _create_tensor(value, device=None, dtype=None):
|
||||||
@ -242,6 +244,12 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
|||||||
RuntimeError, "Can't copy Tensor with type string to device"):
|
RuntimeError, "Can't copy Tensor with type string to device"):
|
||||||
_create_tensor("test string")
|
_create_tensor("test string")
|
||||||
|
|
||||||
|
def testInvalidUTF8ProducesReasonableError(self):
|
||||||
|
if sys.version_info[0] < 3:
|
||||||
|
self.skipTest("Test is only valid in python3.")
|
||||||
|
with self.assertRaises(UnicodeDecodeError):
|
||||||
|
io_ops.read_file(b"\xff")
|
||||||
|
|
||||||
|
|
||||||
class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user