Addressing large number of outputs in a consistent way across our various C
APIs (Eager fastpath, Eager slowpath, graph). Previously in graph mode, we would error out for somewhat the wrong reason but the error message was decent. This CL makes it more precise. In fastpath eager mode, we tried parsing an int64 attr to an int32, failed and unnecesarily hit the slow path. In the slowpath, we used to just try and allocate a huge output number of handles causing a crash. This CL just checks whether the output size is greater than int32 and errors out before trying to do any allocation. Fixes https://github.com/tensorflow/tensorflow/issues/42281 PiperOrigin-RevId: 339150482 Change-Id: I6d9d1541ccfa9881f92a06348567562f79b0a963
This commit is contained in:
parent
d294dd8a20
commit
7bd42cf6ba
@ -447,9 +447,13 @@ Status AddArgToSig(const NodeDefOrAttrSlice& node_or_attrs,
|
||||
const int original_size = sig->size();
|
||||
if (!arg_def.number_attr().empty()) {
|
||||
// Same type repeated "repeats" times.
|
||||
int32 repeats = -1;
|
||||
int64 repeats = -1;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(node_or_attrs, arg_def.number_attr(), &repeats));
|
||||
// We can't handle outputs that are larger than int32 sizes.
|
||||
if (static_cast<int64>(static_cast<int32>(repeats)) != repeats) {
|
||||
return errors::InvalidArgument("Number of outputs is too big: ", repeats);
|
||||
}
|
||||
if (repeats < 0) {
|
||||
return errors::InvalidArgument("Value for number_attr() ", repeats,
|
||||
" < 0");
|
||||
|
@ -448,6 +448,20 @@ TEST(OutputTypesForNode, Simple) {
|
||||
EXPECT_FALSE(OutputTypeForNode(node_def, op_def, 2, &type).ok());
|
||||
}
|
||||
|
||||
TEST(OutputTypesForNode, LargeOutput) {
|
||||
const OpDef op_def = ToOpDef(OpDefBuilder("TestSplitOp")
|
||||
.Input("value: int64")
|
||||
.Output("output: num_split * int64")
|
||||
.Attr("num_split: int >= 1"));
|
||||
int64 num_split = 1000000000000;
|
||||
const NodeDef node_def =
|
||||
ToNodeDef(std::move(NodeDefBuilder("test_split_op", &op_def)
|
||||
.Input(FakeInput())
|
||||
.Attr("num_split", num_split)));
|
||||
DataTypeVector types;
|
||||
EXPECT_FALSE(OutputTypesForNode(node_def, op_def, &types).ok());
|
||||
}
|
||||
|
||||
TEST(OutputTypesForNode_AttrSliceOverload, Simple) {
|
||||
const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
|
||||
.Input("a: float")
|
||||
|
@ -3817,10 +3817,10 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
|
||||
}
|
||||
}
|
||||
|
||||
int num_retvals = 0;
|
||||
int64_t num_outputs = 0;
|
||||
for (int i = 0; i < op_def->output_arg_size(); i++) {
|
||||
const auto& output_arg = op_def->output_arg(i);
|
||||
int delta = 1;
|
||||
int64_t delta = 1;
|
||||
if (!output_arg.number_attr().empty()) {
|
||||
delta = attr_list_sizes[output_arg.number_attr()];
|
||||
} else if (!output_arg.type_list_attr().empty()) {
|
||||
@ -3831,9 +3831,18 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
|
||||
"Attributes suggest that the size of an output list is less than 0");
|
||||
return nullptr;
|
||||
}
|
||||
num_retvals += delta;
|
||||
num_outputs += delta;
|
||||
}
|
||||
|
||||
// If number of retvals is larger than int32, we error out.
|
||||
if (static_cast<int64_t>(static_cast<int32_t>(num_outputs)) != num_outputs) {
|
||||
PyErr_SetString(
|
||||
PyExc_ValueError,
|
||||
Printf("Number of outputs is too big: %ld", num_outputs).c_str());
|
||||
return nullptr;
|
||||
}
|
||||
int num_retvals = num_outputs;
|
||||
|
||||
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);
|
||||
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
|
@ -224,6 +224,36 @@ class Tests(test.TestCase):
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx, "Split", None,
|
||||
split_dim, value, "num_split", -1)
|
||||
|
||||
@test_util.assert_no_new_tensors
|
||||
@test_util.assert_no_garbage_created
|
||||
def testFastPathExecute_VeryLargeOutputs(self):
|
||||
split_dim = constant_op.constant(0, dtype=dtypes.int32)
|
||||
value = constant_op.constant([0, 1, 2, 3], dtype=dtypes.float32)
|
||||
ctx = context.context()
|
||||
ctx.ensure_initialized()
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Number of outputs is too big"):
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx, "Split", None, split_dim, value,
|
||||
"num_split", 1000000000000)
|
||||
|
||||
@test_util.assert_no_new_tensors
|
||||
@test_util.assert_no_garbage_created
|
||||
def testSlowPathExecute_VeryLargeOutputs(self):
|
||||
split_dim = constant_op.constant(0, dtype=dtypes.int32)
|
||||
value = [0, 1, 2, 3]
|
||||
ctx = context.context()
|
||||
ctx.ensure_initialized()
|
||||
|
||||
with self.assertRaises(core._FallbackException):
|
||||
pywrap_tfe.TFE_Py_FastPathExecute(ctx, "Split", None, split_dim, value,
|
||||
"num_split", 1000000000000)
|
||||
|
||||
value = constant_op.constant(value)
|
||||
attrs = ("num_splits", 1000000000000)
|
||||
with self.assertRaisesRegex(ValueError, "Number of outputs is too big"):
|
||||
pywrap_tfe.TFE_Py_Execute(ctx._handle, None, "Split", [value], attrs,
|
||||
1000000000000)
|
||||
|
||||
@test_util.assert_no_new_tensors
|
||||
@test_util.assert_no_garbage_created
|
||||
def testInvalidNumOutputs(self):
|
||||
|
@ -204,6 +204,13 @@ TFE_OutputTensorHandles InputTFE_OutputTensorHandles(
|
||||
#else
|
||||
long sz = PyLong_AsLong(num_outputs.ptr()); // NOLINT
|
||||
#endif
|
||||
// We can't handle more than int32 sizes for number of outputs.
|
||||
if (static_cast<long>(static_cast<int32>(sz)) != sz) { // NOLINT
|
||||
PyErr_SetString(PyExc_ValueError, tensorflow::strings::StrCat(
|
||||
"Number of outputs is too big: ", sz)
|
||||
.c_str());
|
||||
throw py::error_already_set();
|
||||
}
|
||||
if (sz > 0) {
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
output_tensor_handles.resize(PyInt_AsLong(num_outputs.ptr()), nullptr);
|
||||
|
Loading…
Reference in New Issue
Block a user