Raise an exception when converting lists with invalid lengths to Tensors instead of CHECK failing
PiperOrigin-RevId: 177324815
This commit is contained in:
parent
2229a6cbbe
commit
7921d01ec8
@ -899,6 +899,8 @@ set (pywrap_tensorflow_internal_src
|
||||
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.h"
|
||||
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.h"
|
||||
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.h"
|
||||
"${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/python/lib/io/py_record_reader.h"
|
||||
|
@ -268,6 +268,7 @@ cc_library(
|
||||
deps = [
|
||||
":ndarray_tensor_bridge",
|
||||
":numpy_lib",
|
||||
":py_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -309,6 +310,7 @@ cc_library(
|
||||
hdrs = ["lib/core/py_seq_tensor.h"],
|
||||
deps = [
|
||||
":numpy_lib",
|
||||
":py_util",
|
||||
":safe_ptr",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
@ -316,6 +318,17 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "py_util",
|
||||
srcs = ["lib/core/py_util.cc"],
|
||||
hdrs = ["lib/core/py_util.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:script_ops_op_lib",
|
||||
"//util/python:python_headers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "py_record_reader_lib",
|
||||
srcs = ["lib/io/py_record_reader.cc"],
|
||||
|
@ -237,6 +237,39 @@ class ConstantTest(test.TestCase):
|
||||
self._testAll((1, x))
|
||||
self._testAll((x, 1))
|
||||
|
||||
def testInvalidLength(self):
|
||||
|
||||
class BadList(list):
|
||||
|
||||
def __init__(self):
|
||||
super(BadList, self).__init__([1, 2, 3]) # pylint: disable=invalid-length-returned
|
||||
|
||||
def __len__(self):
|
||||
return -1
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "should return >= 0"):
|
||||
constant_op.constant([BadList()])
|
||||
with self.assertRaisesRegexp(ValueError, "mixed types"):
|
||||
constant_op.constant([1, 2, BadList()])
|
||||
with self.assertRaisesRegexp(ValueError, "should return >= 0"):
|
||||
constant_op.constant(BadList())
|
||||
with self.assertRaisesRegexp(ValueError, "should return >= 0"):
|
||||
constant_op.constant([[BadList(), 2], 3])
|
||||
with self.assertRaisesRegexp(ValueError, "should return >= 0"):
|
||||
constant_op.constant([BadList(), [1, 2, 3]])
|
||||
with self.assertRaisesRegexp(ValueError, "should return >= 0"):
|
||||
constant_op.constant([BadList(), []])
|
||||
|
||||
# TODO(allenl, josh11b): These cases should return exceptions rather than
|
||||
# working (currently shape checking only checks the first element of each
|
||||
# sequence recursively). Maybe the first one is fine, but the second one
|
||||
# silently truncating is rather bad.
|
||||
|
||||
# with self.assertRaisesRegexp(ValueError, "should return >= 0"):
|
||||
# constant_op.constant([[3, 2, 1], BadList()])
|
||||
# with self.assertRaisesRegexp(ValueError, "should return >= 0"):
|
||||
# constant_op.constant([[], BadList()])
|
||||
|
||||
def testSparseValuesRaiseErrors(self):
|
||||
with self.assertRaisesRegexp(ValueError, "non-rectangular Python sequence"):
|
||||
constant_op.constant([[1, 2], [3]], dtype=dtypes_lib.int32)
|
||||
|
@ -22,11 +22,11 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
|
||||
#include "tensorflow/python/lib/core/py_util.h"
|
||||
#include <Python.h>
|
||||
|
||||
namespace tensorflow {
|
||||
@ -133,48 +133,6 @@ bool IsSingleNone(PyObject* obj) {
|
||||
return item == Py_None;
|
||||
}
|
||||
|
||||
// py.__class__.__name__
|
||||
const char* ClassName(PyObject* py) {
|
||||
/* PyPy doesn't have a separate C API for old-style classes. */
|
||||
#if PY_MAJOR_VERSION < 3 && !defined(PYPY_VERSION)
|
||||
if (PyClass_Check(py))
|
||||
return PyString_AS_STRING(
|
||||
CHECK_NOTNULL(reinterpret_cast<PyClassObject*>(py)->cl_name));
|
||||
if (PyInstance_Check(py))
|
||||
return PyString_AS_STRING(CHECK_NOTNULL(
|
||||
reinterpret_cast<PyInstanceObject*>(py)->in_class->cl_name));
|
||||
#endif
|
||||
if (Py_TYPE(py) == &PyType_Type) {
|
||||
return reinterpret_cast<PyTypeObject*>(py)->tp_name;
|
||||
}
|
||||
return Py_TYPE(py)->tp_name;
|
||||
}
|
||||
|
||||
string PyExcFetch() {
|
||||
CHECK(PyErr_Occurred()) << "Must only call PyExcFetch after an exception.";
|
||||
PyObject* ptype;
|
||||
PyObject* pvalue;
|
||||
PyObject* ptraceback;
|
||||
PyErr_Fetch(&ptype, &pvalue, &ptraceback);
|
||||
PyErr_NormalizeException(&ptype, &pvalue, &ptraceback);
|
||||
string err = ClassName(ptype);
|
||||
if (pvalue) {
|
||||
PyObject* str = PyObject_Str(pvalue);
|
||||
if (str) {
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
strings::StrAppend(&err, ": ", PyString_AS_STRING(str));
|
||||
#else
|
||||
strings::StrAppend(&err, ": ", PyUnicode_AsUTF8(str));
|
||||
#endif
|
||||
Py_DECREF(str);
|
||||
}
|
||||
Py_DECREF(pvalue);
|
||||
}
|
||||
Py_DECREF(ptype);
|
||||
Py_XDECREF(ptraceback);
|
||||
return err;
|
||||
}
|
||||
|
||||
// Calls the registered py function through the trampoline.
|
||||
Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
|
||||
*out_log_on_error = true;
|
||||
@ -195,18 +153,18 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
|
||||
if (PyErr_Occurred()) {
|
||||
if (PyErr_ExceptionMatches(PyExc_ValueError) ||
|
||||
PyErr_ExceptionMatches(PyExc_TypeError)) {
|
||||
return errors::InvalidArgument(PyExcFetch());
|
||||
return errors::InvalidArgument(PyExceptionFetch());
|
||||
} else if (PyErr_ExceptionMatches(PyExc_StopIteration)) {
|
||||
*out_log_on_error = false;
|
||||
return errors::OutOfRange(PyExcFetch());
|
||||
return errors::OutOfRange(PyExceptionFetch());
|
||||
} else if (PyErr_ExceptionMatches(PyExc_MemoryError)) {
|
||||
return errors::ResourceExhausted(PyExcFetch());
|
||||
return errors::ResourceExhausted(PyExceptionFetch());
|
||||
} else if (PyErr_ExceptionMatches(PyExc_NotImplementedError)) {
|
||||
return errors::Unimplemented(PyExcFetch());
|
||||
return errors::Unimplemented(PyExceptionFetch());
|
||||
} else {
|
||||
// TODO(ebrevdo): Check if exception is an OpError and use the
|
||||
// OpError.error_code property to map it back in the Status.
|
||||
return errors::Unknown(PyExcFetch());
|
||||
return errors::Unknown(PyExceptionFetch());
|
||||
}
|
||||
} else {
|
||||
return errors::Internal("Failed to run py callback ", call->token,
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/python/lib/core/numpy.h"
|
||||
#include "tensorflow/python/lib/core/py_util.h"
|
||||
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -89,12 +90,25 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
|
||||
*dtype = DT_STRING;
|
||||
} else if (PySequence_Check(obj)) {
|
||||
auto length = PySequence_Length(obj);
|
||||
shape->AddDim(length);
|
||||
if (length > 0) {
|
||||
shape->AddDim(length);
|
||||
obj = PySequence_GetItem(obj, 0);
|
||||
continue;
|
||||
} else {
|
||||
} else if (length == 0) {
|
||||
shape->AddDim(length);
|
||||
*dtype = DT_INVALID; // Invalid dtype for empty tensors.
|
||||
} else {
|
||||
// The sequence does not have a valid length (PySequence_Length < 0).
|
||||
if (PyErr_Occurred()) {
|
||||
// PySequence_Length failed and set an exception. Fetch the message
|
||||
// and convert it to a failed status.
|
||||
return errors::InvalidArgument(PyExceptionFetch());
|
||||
} else {
|
||||
// This is almost certainly dead code: PySequence_Length failed but
|
||||
// did not set an exception.
|
||||
return errors::InvalidArgument(
|
||||
"Attempted to convert an invalid sequence to a Tensor.");
|
||||
}
|
||||
}
|
||||
} else if (IsPyFloat(obj)) {
|
||||
*dtype = DT_DOUBLE;
|
||||
|
70
tensorflow/python/lib/core/py_util.cc
Normal file
70
tensorflow/python/lib/core/py_util.cc
Normal file
@ -0,0 +1,70 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/python/lib/core/py_util.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include <Python.h>
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// py.__class__.__name__
|
||||
const char* ClassName(PyObject* py) {
|
||||
/* PyPy doesn't have a separate C API for old-style classes. */
|
||||
#if PY_MAJOR_VERSION < 3 && !defined(PYPY_VERSION)
|
||||
if (PyClass_Check(py))
|
||||
return PyString_AS_STRING(
|
||||
CHECK_NOTNULL(reinterpret_cast<PyClassObject*>(py)->cl_name));
|
||||
if (PyInstance_Check(py))
|
||||
return PyString_AS_STRING(CHECK_NOTNULL(
|
||||
reinterpret_cast<PyInstanceObject*>(py)->in_class->cl_name));
|
||||
#endif
|
||||
if (Py_TYPE(py) == &PyType_Type) {
|
||||
return reinterpret_cast<PyTypeObject*>(py)->tp_name;
|
||||
}
|
||||
return Py_TYPE(py)->tp_name;
|
||||
}
|
||||
|
||||
} // end namespace
|
||||
|
||||
string PyExceptionFetch() {
|
||||
CHECK(PyErr_Occurred())
|
||||
<< "Must only call PyExceptionFetch after an exception.";
|
||||
PyObject* ptype;
|
||||
PyObject* pvalue;
|
||||
PyObject* ptraceback;
|
||||
PyErr_Fetch(&ptype, &pvalue, &ptraceback);
|
||||
PyErr_NormalizeException(&ptype, &pvalue, &ptraceback);
|
||||
string err = ClassName(ptype);
|
||||
if (pvalue) {
|
||||
PyObject* str = PyObject_Str(pvalue);
|
||||
if (str) {
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
strings::StrAppend(&err, ": ", PyString_AS_STRING(str));
|
||||
#else
|
||||
strings::StrAppend(&err, ": ", PyUnicode_AsUTF8(str));
|
||||
#endif
|
||||
Py_DECREF(str);
|
||||
}
|
||||
Py_DECREF(pvalue);
|
||||
}
|
||||
Py_DECREF(ptype);
|
||||
Py_XDECREF(ptraceback);
|
||||
return err;
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
27
tensorflow/python/lib/core/py_util.h
Normal file
27
tensorflow/python/lib/core/py_util.h
Normal file
@ -0,0 +1,27 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_PYTHON_LIB_CORE_UTIL_H_
|
||||
#define TENSORFLOW_PYTHON_LIB_CORE_UTIL_H_
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
// Fetch the exception message as a string. An exception must be set
|
||||
// (PyErr_Occurred() must be true).
|
||||
string PyExceptionFetch();
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_LIB_CORE_UTIL_H_
|
@ -99,7 +99,8 @@ do_pylint() {
|
||||
"^tensorflow/contrib/eager/python/metrics_impl\.py.*\[E0202.*method-hidden "\
|
||||
"^tensorflow/python/platform/gfile\.py.*\[E0301.*non-iterator "\
|
||||
"^tensorflow/python/keras/_impl/keras/callbacks\.py.*\[E1133.*not-an-iterable "\
|
||||
"^tensorflow/python/keras/_impl/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition"
|
||||
"^tensorflow/python/keras/_impl/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition "\
|
||||
"^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned"
|
||||
|
||||
echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\""
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user