Raise an exception when converting lists with invalid lengths to Tensors instead of CHECK failing

PiperOrigin-RevId: 177324815
This commit is contained in:
Allen Lavoie 2017-11-29 10:06:59 -08:00 committed by TensorFlower Gardener
parent 2229a6cbbe
commit 7921d01ec8
8 changed files with 169 additions and 51 deletions

View File

@ -899,6 +899,8 @@ set (pywrap_tensorflow_internal_src
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.cc"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.h"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.cc"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.h"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.cc"
"${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.h"
"${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.cc"
"${tensorflow_source_dir}/tensorflow/python/lib/io/py_record_reader.h"

View File

@ -268,6 +268,7 @@ cc_library(
deps = [
":ndarray_tensor_bridge",
":numpy_lib",
":py_util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -309,6 +310,7 @@ cc_library(
hdrs = ["lib/core/py_seq_tensor.h"],
deps = [
":numpy_lib",
":py_util",
":safe_ptr",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -316,6 +318,17 @@ cc_library(
],
)
cc_library(
name = "py_util",
srcs = ["lib/core/py_util.cc"],
hdrs = ["lib/core/py_util.h"],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:script_ops_op_lib",
"//util/python:python_headers",
],
)
cc_library(
name = "py_record_reader_lib",
srcs = ["lib/io/py_record_reader.cc"],

View File

@ -237,6 +237,39 @@ class ConstantTest(test.TestCase):
self._testAll((1, x))
self._testAll((x, 1))
def testInvalidLength(self):
class BadList(list):
def __init__(self):
super(BadList, self).__init__([1, 2, 3]) # pylint: disable=invalid-length-returned
def __len__(self):
return -1
with self.assertRaisesRegexp(ValueError, "should return >= 0"):
constant_op.constant([BadList()])
with self.assertRaisesRegexp(ValueError, "mixed types"):
constant_op.constant([1, 2, BadList()])
with self.assertRaisesRegexp(ValueError, "should return >= 0"):
constant_op.constant(BadList())
with self.assertRaisesRegexp(ValueError, "should return >= 0"):
constant_op.constant([[BadList(), 2], 3])
with self.assertRaisesRegexp(ValueError, "should return >= 0"):
constant_op.constant([BadList(), [1, 2, 3]])
with self.assertRaisesRegexp(ValueError, "should return >= 0"):
constant_op.constant([BadList(), []])
# TODO(allenl, josh11b): These cases should return exceptions rather than
# working (currently shape checking only checks the first element of each
# sequence recursively). Maybe the first one is fine, but the second one
# silently truncating is rather bad.
# with self.assertRaisesRegexp(ValueError, "should return >= 0"):
# constant_op.constant([[3, 2, 1], BadList()])
# with self.assertRaisesRegexp(ValueError, "should return >= 0"):
# constant_op.constant([[], BadList()])
def testSparseValuesRaiseErrors(self):
with self.assertRaisesRegexp(ValueError, "non-rectangular Python sequence"):
constant_op.constant([[1, 2], [3]], dtype=dtypes_lib.int32)

View File

@ -22,11 +22,11 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
#include "tensorflow/python/lib/core/py_util.h"
#include <Python.h>
namespace tensorflow {
@ -133,48 +133,6 @@ bool IsSingleNone(PyObject* obj) {
return item == Py_None;
}
// py.__class__.__name__
const char* ClassName(PyObject* py) {
/* PyPy doesn't have a separate C API for old-style classes. */
#if PY_MAJOR_VERSION < 3 && !defined(PYPY_VERSION)
if (PyClass_Check(py))
return PyString_AS_STRING(
CHECK_NOTNULL(reinterpret_cast<PyClassObject*>(py)->cl_name));
if (PyInstance_Check(py))
return PyString_AS_STRING(CHECK_NOTNULL(
reinterpret_cast<PyInstanceObject*>(py)->in_class->cl_name));
#endif
if (Py_TYPE(py) == &PyType_Type) {
return reinterpret_cast<PyTypeObject*>(py)->tp_name;
}
return Py_TYPE(py)->tp_name;
}
string PyExcFetch() {
CHECK(PyErr_Occurred()) << "Must only call PyExcFetch after an exception.";
PyObject* ptype;
PyObject* pvalue;
PyObject* ptraceback;
PyErr_Fetch(&ptype, &pvalue, &ptraceback);
PyErr_NormalizeException(&ptype, &pvalue, &ptraceback);
string err = ClassName(ptype);
if (pvalue) {
PyObject* str = PyObject_Str(pvalue);
if (str) {
#if PY_MAJOR_VERSION < 3
strings::StrAppend(&err, ": ", PyString_AS_STRING(str));
#else
strings::StrAppend(&err, ": ", PyUnicode_AsUTF8(str));
#endif
Py_DECREF(str);
}
Py_DECREF(pvalue);
}
Py_DECREF(ptype);
Py_XDECREF(ptraceback);
return err;
}
// Calls the registered py function through the trampoline.
Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
*out_log_on_error = true;
@ -195,18 +153,18 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
if (PyErr_Occurred()) {
if (PyErr_ExceptionMatches(PyExc_ValueError) ||
PyErr_ExceptionMatches(PyExc_TypeError)) {
return errors::InvalidArgument(PyExcFetch());
return errors::InvalidArgument(PyExceptionFetch());
} else if (PyErr_ExceptionMatches(PyExc_StopIteration)) {
*out_log_on_error = false;
return errors::OutOfRange(PyExcFetch());
return errors::OutOfRange(PyExceptionFetch());
} else if (PyErr_ExceptionMatches(PyExc_MemoryError)) {
return errors::ResourceExhausted(PyExcFetch());
return errors::ResourceExhausted(PyExceptionFetch());
} else if (PyErr_ExceptionMatches(PyExc_NotImplementedError)) {
return errors::Unimplemented(PyExcFetch());
return errors::Unimplemented(PyExceptionFetch());
} else {
// TODO(ebrevdo): Check if exception is an OpError and use the
// OpError.error_code property to map it back in the Status.
return errors::Unknown(PyExcFetch());
return errors::Unknown(PyExceptionFetch());
}
} else {
return errors::Internal("Failed to run py callback ", call->token,

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/lib/core/numpy.h"
#include "tensorflow/python/lib/core/py_util.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
namespace tensorflow {
@ -89,12 +90,25 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
*dtype = DT_STRING;
} else if (PySequence_Check(obj)) {
auto length = PySequence_Length(obj);
shape->AddDim(length);
if (length > 0) {
shape->AddDim(length);
obj = PySequence_GetItem(obj, 0);
continue;
} else {
} else if (length == 0) {
shape->AddDim(length);
*dtype = DT_INVALID; // Invalid dtype for empty tensors.
} else {
// The sequence does not have a valid length (PySequence_Length < 0).
if (PyErr_Occurred()) {
// PySequence_Length failed and set an exception. Fetch the message
// and convert it to a failed status.
return errors::InvalidArgument(PyExceptionFetch());
} else {
// This is almost certainly dead code: PySequence_Length failed but
// did not set an exception.
return errors::InvalidArgument(
"Attempted to convert an invalid sequence to a Tensor.");
}
}
} else if (IsPyFloat(obj)) {
*dtype = DT_DOUBLE;

View File

@ -0,0 +1,70 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/lib/core/py_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include <Python.h>
namespace tensorflow {
namespace {
// py.__class__.__name__
const char* ClassName(PyObject* py) {
/* PyPy doesn't have a separate C API for old-style classes. */
#if PY_MAJOR_VERSION < 3 && !defined(PYPY_VERSION)
if (PyClass_Check(py))
return PyString_AS_STRING(
CHECK_NOTNULL(reinterpret_cast<PyClassObject*>(py)->cl_name));
if (PyInstance_Check(py))
return PyString_AS_STRING(CHECK_NOTNULL(
reinterpret_cast<PyInstanceObject*>(py)->in_class->cl_name));
#endif
if (Py_TYPE(py) == &PyType_Type) {
return reinterpret_cast<PyTypeObject*>(py)->tp_name;
}
return Py_TYPE(py)->tp_name;
}
} // end namespace
string PyExceptionFetch() {
CHECK(PyErr_Occurred())
<< "Must only call PyExceptionFetch after an exception.";
PyObject* ptype;
PyObject* pvalue;
PyObject* ptraceback;
PyErr_Fetch(&ptype, &pvalue, &ptraceback);
PyErr_NormalizeException(&ptype, &pvalue, &ptraceback);
string err = ClassName(ptype);
if (pvalue) {
PyObject* str = PyObject_Str(pvalue);
if (str) {
#if PY_MAJOR_VERSION < 3
strings::StrAppend(&err, ": ", PyString_AS_STRING(str));
#else
strings::StrAppend(&err, ": ", PyUnicode_AsUTF8(str));
#endif
Py_DECREF(str);
}
Py_DECREF(pvalue);
}
Py_DECREF(ptype);
Py_XDECREF(ptraceback);
return err;
}
} // end namespace tensorflow

View File

@ -0,0 +1,27 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_PYTHON_LIB_CORE_UTIL_H_
#define TENSORFLOW_PYTHON_LIB_CORE_UTIL_H_
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
// Fetch the exception message as a string. An exception must be set
// (PyErr_Occurred() must be true).
string PyExceptionFetch();
} // end namespace tensorflow
#endif // TENSORFLOW_PYTHON_LIB_CORE_UTIL_H_

View File

@ -99,7 +99,8 @@ do_pylint() {
"^tensorflow/contrib/eager/python/metrics_impl\.py.*\[E0202.*method-hidden "\
"^tensorflow/python/platform/gfile\.py.*\[E0301.*non-iterator "\
"^tensorflow/python/keras/_impl/keras/callbacks\.py.*\[E1133.*not-an-iterable "\
"^tensorflow/python/keras/_impl/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition"
"^tensorflow/python/keras/_impl/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition "\
"^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned"
echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\""