Raise an exception when converting lists with invalid lengths to Tensors instead of CHECK failing
PiperOrigin-RevId: 177324815
This commit is contained in:
parent
2229a6cbbe
commit
7921d01ec8
@ -899,6 +899,8 @@ set (pywrap_tensorflow_internal_src
|
|||||||
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.cc"
|
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.cc"
|
||||||
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.h"
|
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.h"
|
||||||
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.cc"
|
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.cc"
|
||||||
|
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.h"
|
||||||
|
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.cc"
|
||||||
"${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.h"
|
"${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.h"
|
||||||
"${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.cc"
|
"${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.cc"
|
||||||
"${tensorflow_source_dir}/tensorflow/python/lib/io/py_record_reader.h"
|
"${tensorflow_source_dir}/tensorflow/python/lib/io/py_record_reader.h"
|
||||||
|
@ -268,6 +268,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":ndarray_tensor_bridge",
|
":ndarray_tensor_bridge",
|
||||||
":numpy_lib",
|
":numpy_lib",
|
||||||
|
":py_util",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
@ -309,6 +310,7 @@ cc_library(
|
|||||||
hdrs = ["lib/core/py_seq_tensor.h"],
|
hdrs = ["lib/core/py_seq_tensor.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":numpy_lib",
|
":numpy_lib",
|
||||||
|
":py_util",
|
||||||
":safe_ptr",
|
":safe_ptr",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
@ -316,6 +318,17 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "py_util",
|
||||||
|
srcs = ["lib/core/py_util.cc"],
|
||||||
|
hdrs = ["lib/core/py_util.h"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:script_ops_op_lib",
|
||||||
|
"//util/python:python_headers",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "py_record_reader_lib",
|
name = "py_record_reader_lib",
|
||||||
srcs = ["lib/io/py_record_reader.cc"],
|
srcs = ["lib/io/py_record_reader.cc"],
|
||||||
|
@ -237,6 +237,39 @@ class ConstantTest(test.TestCase):
|
|||||||
self._testAll((1, x))
|
self._testAll((1, x))
|
||||||
self._testAll((x, 1))
|
self._testAll((x, 1))
|
||||||
|
|
||||||
|
def testInvalidLength(self):
|
||||||
|
|
||||||
|
class BadList(list):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(BadList, self).__init__([1, 2, 3]) # pylint: disable=invalid-length-returned
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return -1
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, "should return >= 0"):
|
||||||
|
constant_op.constant([BadList()])
|
||||||
|
with self.assertRaisesRegexp(ValueError, "mixed types"):
|
||||||
|
constant_op.constant([1, 2, BadList()])
|
||||||
|
with self.assertRaisesRegexp(ValueError, "should return >= 0"):
|
||||||
|
constant_op.constant(BadList())
|
||||||
|
with self.assertRaisesRegexp(ValueError, "should return >= 0"):
|
||||||
|
constant_op.constant([[BadList(), 2], 3])
|
||||||
|
with self.assertRaisesRegexp(ValueError, "should return >= 0"):
|
||||||
|
constant_op.constant([BadList(), [1, 2, 3]])
|
||||||
|
with self.assertRaisesRegexp(ValueError, "should return >= 0"):
|
||||||
|
constant_op.constant([BadList(), []])
|
||||||
|
|
||||||
|
# TODO(allenl, josh11b): These cases should return exceptions rather than
|
||||||
|
# working (currently shape checking only checks the first element of each
|
||||||
|
# sequence recursively). Maybe the first one is fine, but the second one
|
||||||
|
# silently truncating is rather bad.
|
||||||
|
|
||||||
|
# with self.assertRaisesRegexp(ValueError, "should return >= 0"):
|
||||||
|
# constant_op.constant([[3, 2, 1], BadList()])
|
||||||
|
# with self.assertRaisesRegexp(ValueError, "should return >= 0"):
|
||||||
|
# constant_op.constant([[], BadList()])
|
||||||
|
|
||||||
def testSparseValuesRaiseErrors(self):
|
def testSparseValuesRaiseErrors(self):
|
||||||
with self.assertRaisesRegexp(ValueError, "non-rectangular Python sequence"):
|
with self.assertRaisesRegexp(ValueError, "non-rectangular Python sequence"):
|
||||||
constant_op.constant([[1, 2], [3]], dtype=dtypes_lib.int32)
|
constant_op.constant([[1, 2], [3]], dtype=dtypes_lib.int32)
|
||||||
|
@ -22,11 +22,11 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/core/threadpool.h"
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
|
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
|
||||||
|
#include "tensorflow/python/lib/core/py_util.h"
|
||||||
#include <Python.h>
|
#include <Python.h>
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -133,48 +133,6 @@ bool IsSingleNone(PyObject* obj) {
|
|||||||
return item == Py_None;
|
return item == Py_None;
|
||||||
}
|
}
|
||||||
|
|
||||||
// py.__class__.__name__
|
|
||||||
const char* ClassName(PyObject* py) {
|
|
||||||
/* PyPy doesn't have a separate C API for old-style classes. */
|
|
||||||
#if PY_MAJOR_VERSION < 3 && !defined(PYPY_VERSION)
|
|
||||||
if (PyClass_Check(py))
|
|
||||||
return PyString_AS_STRING(
|
|
||||||
CHECK_NOTNULL(reinterpret_cast<PyClassObject*>(py)->cl_name));
|
|
||||||
if (PyInstance_Check(py))
|
|
||||||
return PyString_AS_STRING(CHECK_NOTNULL(
|
|
||||||
reinterpret_cast<PyInstanceObject*>(py)->in_class->cl_name));
|
|
||||||
#endif
|
|
||||||
if (Py_TYPE(py) == &PyType_Type) {
|
|
||||||
return reinterpret_cast<PyTypeObject*>(py)->tp_name;
|
|
||||||
}
|
|
||||||
return Py_TYPE(py)->tp_name;
|
|
||||||
}
|
|
||||||
|
|
||||||
string PyExcFetch() {
|
|
||||||
CHECK(PyErr_Occurred()) << "Must only call PyExcFetch after an exception.";
|
|
||||||
PyObject* ptype;
|
|
||||||
PyObject* pvalue;
|
|
||||||
PyObject* ptraceback;
|
|
||||||
PyErr_Fetch(&ptype, &pvalue, &ptraceback);
|
|
||||||
PyErr_NormalizeException(&ptype, &pvalue, &ptraceback);
|
|
||||||
string err = ClassName(ptype);
|
|
||||||
if (pvalue) {
|
|
||||||
PyObject* str = PyObject_Str(pvalue);
|
|
||||||
if (str) {
|
|
||||||
#if PY_MAJOR_VERSION < 3
|
|
||||||
strings::StrAppend(&err, ": ", PyString_AS_STRING(str));
|
|
||||||
#else
|
|
||||||
strings::StrAppend(&err, ": ", PyUnicode_AsUTF8(str));
|
|
||||||
#endif
|
|
||||||
Py_DECREF(str);
|
|
||||||
}
|
|
||||||
Py_DECREF(pvalue);
|
|
||||||
}
|
|
||||||
Py_DECREF(ptype);
|
|
||||||
Py_XDECREF(ptraceback);
|
|
||||||
return err;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calls the registered py function through the trampoline.
|
// Calls the registered py function through the trampoline.
|
||||||
Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
|
Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
|
||||||
*out_log_on_error = true;
|
*out_log_on_error = true;
|
||||||
@ -195,18 +153,18 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
|
|||||||
if (PyErr_Occurred()) {
|
if (PyErr_Occurred()) {
|
||||||
if (PyErr_ExceptionMatches(PyExc_ValueError) ||
|
if (PyErr_ExceptionMatches(PyExc_ValueError) ||
|
||||||
PyErr_ExceptionMatches(PyExc_TypeError)) {
|
PyErr_ExceptionMatches(PyExc_TypeError)) {
|
||||||
return errors::InvalidArgument(PyExcFetch());
|
return errors::InvalidArgument(PyExceptionFetch());
|
||||||
} else if (PyErr_ExceptionMatches(PyExc_StopIteration)) {
|
} else if (PyErr_ExceptionMatches(PyExc_StopIteration)) {
|
||||||
*out_log_on_error = false;
|
*out_log_on_error = false;
|
||||||
return errors::OutOfRange(PyExcFetch());
|
return errors::OutOfRange(PyExceptionFetch());
|
||||||
} else if (PyErr_ExceptionMatches(PyExc_MemoryError)) {
|
} else if (PyErr_ExceptionMatches(PyExc_MemoryError)) {
|
||||||
return errors::ResourceExhausted(PyExcFetch());
|
return errors::ResourceExhausted(PyExceptionFetch());
|
||||||
} else if (PyErr_ExceptionMatches(PyExc_NotImplementedError)) {
|
} else if (PyErr_ExceptionMatches(PyExc_NotImplementedError)) {
|
||||||
return errors::Unimplemented(PyExcFetch());
|
return errors::Unimplemented(PyExceptionFetch());
|
||||||
} else {
|
} else {
|
||||||
// TODO(ebrevdo): Check if exception is an OpError and use the
|
// TODO(ebrevdo): Check if exception is an OpError and use the
|
||||||
// OpError.error_code property to map it back in the Status.
|
// OpError.error_code property to map it back in the Status.
|
||||||
return errors::Unknown(PyExcFetch());
|
return errors::Unknown(PyExceptionFetch());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return errors::Internal("Failed to run py callback ", call->token,
|
return errors::Internal("Failed to run py callback ", call->token,
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
#include "tensorflow/python/lib/core/numpy.h"
|
#include "tensorflow/python/lib/core/numpy.h"
|
||||||
|
#include "tensorflow/python/lib/core/py_util.h"
|
||||||
#include "tensorflow/python/lib/core/safe_ptr.h"
|
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -89,12 +90,25 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
|
|||||||
*dtype = DT_STRING;
|
*dtype = DT_STRING;
|
||||||
} else if (PySequence_Check(obj)) {
|
} else if (PySequence_Check(obj)) {
|
||||||
auto length = PySequence_Length(obj);
|
auto length = PySequence_Length(obj);
|
||||||
shape->AddDim(length);
|
|
||||||
if (length > 0) {
|
if (length > 0) {
|
||||||
|
shape->AddDim(length);
|
||||||
obj = PySequence_GetItem(obj, 0);
|
obj = PySequence_GetItem(obj, 0);
|
||||||
continue;
|
continue;
|
||||||
} else {
|
} else if (length == 0) {
|
||||||
|
shape->AddDim(length);
|
||||||
*dtype = DT_INVALID; // Invalid dtype for empty tensors.
|
*dtype = DT_INVALID; // Invalid dtype for empty tensors.
|
||||||
|
} else {
|
||||||
|
// The sequence does not have a valid length (PySequence_Length < 0).
|
||||||
|
if (PyErr_Occurred()) {
|
||||||
|
// PySequence_Length failed and set an exception. Fetch the message
|
||||||
|
// and convert it to a failed status.
|
||||||
|
return errors::InvalidArgument(PyExceptionFetch());
|
||||||
|
} else {
|
||||||
|
// This is almost certainly dead code: PySequence_Length failed but
|
||||||
|
// did not set an exception.
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Attempted to convert an invalid sequence to a Tensor.");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else if (IsPyFloat(obj)) {
|
} else if (IsPyFloat(obj)) {
|
||||||
*dtype = DT_DOUBLE;
|
*dtype = DT_DOUBLE;
|
||||||
|
70
tensorflow/python/lib/core/py_util.cc
Normal file
70
tensorflow/python/lib/core/py_util.cc
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/python/lib/core/py_util.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
#include <Python.h>
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// py.__class__.__name__
|
||||||
|
const char* ClassName(PyObject* py) {
|
||||||
|
/* PyPy doesn't have a separate C API for old-style classes. */
|
||||||
|
#if PY_MAJOR_VERSION < 3 && !defined(PYPY_VERSION)
|
||||||
|
if (PyClass_Check(py))
|
||||||
|
return PyString_AS_STRING(
|
||||||
|
CHECK_NOTNULL(reinterpret_cast<PyClassObject*>(py)->cl_name));
|
||||||
|
if (PyInstance_Check(py))
|
||||||
|
return PyString_AS_STRING(CHECK_NOTNULL(
|
||||||
|
reinterpret_cast<PyInstanceObject*>(py)->in_class->cl_name));
|
||||||
|
#endif
|
||||||
|
if (Py_TYPE(py) == &PyType_Type) {
|
||||||
|
return reinterpret_cast<PyTypeObject*>(py)->tp_name;
|
||||||
|
}
|
||||||
|
return Py_TYPE(py)->tp_name;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end namespace
|
||||||
|
|
||||||
|
string PyExceptionFetch() {
|
||||||
|
CHECK(PyErr_Occurred())
|
||||||
|
<< "Must only call PyExceptionFetch after an exception.";
|
||||||
|
PyObject* ptype;
|
||||||
|
PyObject* pvalue;
|
||||||
|
PyObject* ptraceback;
|
||||||
|
PyErr_Fetch(&ptype, &pvalue, &ptraceback);
|
||||||
|
PyErr_NormalizeException(&ptype, &pvalue, &ptraceback);
|
||||||
|
string err = ClassName(ptype);
|
||||||
|
if (pvalue) {
|
||||||
|
PyObject* str = PyObject_Str(pvalue);
|
||||||
|
if (str) {
|
||||||
|
#if PY_MAJOR_VERSION < 3
|
||||||
|
strings::StrAppend(&err, ": ", PyString_AS_STRING(str));
|
||||||
|
#else
|
||||||
|
strings::StrAppend(&err, ": ", PyUnicode_AsUTF8(str));
|
||||||
|
#endif
|
||||||
|
Py_DECREF(str);
|
||||||
|
}
|
||||||
|
Py_DECREF(pvalue);
|
||||||
|
}
|
||||||
|
Py_DECREF(ptype);
|
||||||
|
Py_XDECREF(ptraceback);
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end namespace tensorflow
|
27
tensorflow/python/lib/core/py_util.h
Normal file
27
tensorflow/python/lib/core/py_util.h
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_PYTHON_LIB_CORE_UTIL_H_
|
||||||
|
#define TENSORFLOW_PYTHON_LIB_CORE_UTIL_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
// Fetch the exception message as a string. An exception must be set
|
||||||
|
// (PyErr_Occurred() must be true).
|
||||||
|
string PyExceptionFetch();
|
||||||
|
} // end namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_PYTHON_LIB_CORE_UTIL_H_
|
@ -99,7 +99,8 @@ do_pylint() {
|
|||||||
"^tensorflow/contrib/eager/python/metrics_impl\.py.*\[E0202.*method-hidden "\
|
"^tensorflow/contrib/eager/python/metrics_impl\.py.*\[E0202.*method-hidden "\
|
||||||
"^tensorflow/python/platform/gfile\.py.*\[E0301.*non-iterator "\
|
"^tensorflow/python/platform/gfile\.py.*\[E0301.*non-iterator "\
|
||||||
"^tensorflow/python/keras/_impl/keras/callbacks\.py.*\[E1133.*not-an-iterable "\
|
"^tensorflow/python/keras/_impl/keras/callbacks\.py.*\[E1133.*not-an-iterable "\
|
||||||
"^tensorflow/python/keras/_impl/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition"
|
"^tensorflow/python/keras/_impl/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition "\
|
||||||
|
"^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned"
|
||||||
|
|
||||||
echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\""
|
echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\""
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user