Check handle_data vector size before inspecting first element in list_ops shape inference functions. The current behavior can segfault.

PiperOrigin-RevId: 314348644
Change-Id: I16e2acde203b7f2ad505cb1b89dda892ec660ee9
This commit is contained in:
Saurabh Saxena 2020-06-02 09:47:13 -07:00 committed by TensorFlower Gardener
parent 0b4ae9dda2
commit 8cc97997b7

View File

@ -42,6 +42,11 @@ Status VerifyHandleData(
return Status::OK();
}
bool IsValidTensorListHandleData(
const std::vector<shape_inference::ShapeAndType>* handle_data) {
return handle_data != nullptr && handle_data->size() == 1;
}
// Assumes that the handle_data is valid.
shape_inference::ShapeHandle GetElementShapeFromHandleData(
const std::vector<shape_inference::ShapeAndType>& shapes_and_types) {
@ -83,7 +88,7 @@ REGISTER_OP("TensorListPushBack")
return errors::InvalidArgument(
"Trying to push to list with wrong variant data.");
}
if (handle_data != nullptr && handle_data->size() == 1) {
if (IsValidTensorListHandleData(handle_data)) {
const shape_inference::ShapeAndType& list_shape_type =
(*handle_data)[0];
if (list_shape_type.dtype != element_dtype) {
@ -130,7 +135,7 @@ REGISTER_OP("TensorListPushBackBatch")
return errors::InvalidArgument(
"Trying to push to list with wrong variant data.");
}
if (handle_data != nullptr && handle_data->size() == 1) {
if (IsValidTensorListHandleData(handle_data)) {
const shape_inference::ShapeAndType& list_shape_type =
(*handle_data)[0];
if (list_shape_type.dtype != element_dtype) {
@ -171,7 +176,7 @@ REGISTER_OP("TensorListPopBack")
return errors::InvalidArgument(
"Trying to read from list with invalid variant data.");
}
if (handle_data != nullptr && handle_data->size() == 1) {
if (IsValidTensorListHandleData(handle_data)) {
const shape_inference::ShapeAndType& list_shape_type =
(*handle_data)[0];
if (list_shape_type.dtype != element_dtype) {
@ -208,7 +213,7 @@ REGISTER_OP("TensorListStack")
return errors::InvalidArgument(
"Trying to read from list with wrong variant data.");
}
if (handle_data != nullptr && handle_data->size() == 1) {
if (IsValidTensorListHandleData(handle_data)) {
const shape_inference::ShapeAndType& list_shape_type =
(*handle_data)[0];
if (list_shape_type.dtype != element_dtype) {
@ -252,7 +257,7 @@ Status TensorListConcatShapeInference(
return errors::InvalidArgument(
"Trying to read from list with wrong variant data.");
}
if (handle_data != nullptr && handle_data->size() == 1) {
if (IsValidTensorListHandleData(handle_data)) {
const shape_inference::ShapeAndType& list_shape_type = (*handle_data)[0];
if (list_shape_type.dtype != element_dtype) {
return errors::InvalidArgument(
@ -374,7 +379,7 @@ REGISTER_OP("TensorListElementShape")
.Attr("shape_type: {int32, int64}")
.SetShapeFn([](shape_inference::InferenceContext* c) {
auto* handle_data = c->input_handle_shapes_and_types(0);
if (handle_data == nullptr) {
if (!IsValidTensorListHandleData(handle_data)) {
c->set_output(0, c->Vector(c->UnknownDim()));
return Status::OK();
}
@ -412,7 +417,7 @@ REGISTER_OP("TensorListGetItem")
TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
auto* handle_data = c->input_handle_shapes_and_types(0);
shape_inference::ShapeHandle element_shape = c->UnknownShape();
if (handle_data != nullptr) {
if (IsValidTensorListHandleData(handle_data)) {
const shape_inference::ShapeAndType& list_shape_type =
(*handle_data)[0];
element_shape = list_shape_type.shape;
@ -443,7 +448,7 @@ REGISTER_OP("TensorListResize")
TF_RETURN_IF_ERROR(c->WithRank(size_shape, 0, &unused));
c->set_output(0, c->Scalar());
auto* handle_data = c->input_handle_shapes_and_types(0);
if (handle_data != nullptr) {
if (IsValidTensorListHandleData(handle_data)) {
c->set_output_handle_shapes_and_types(0, *handle_data);
}
return Status::OK();
@ -460,16 +465,17 @@ REGISTER_OP("TensorListSetItem")
TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
auto* handle_data = c->input_handle_shapes_and_types(0);
c->set_output(0, c->Scalar());
if (handle_data == nullptr) {
if (IsValidTensorListHandleData(handle_data)) {
const shape_inference::ShapeAndType& list_shape_type =
(*handle_data)[0];
shape_inference::ShapeHandle item_shape = c->input(2);
TF_RETURN_IF_ERROR(
c->Merge(item_shape, list_shape_type.shape, &item_shape));
c->set_output_handle_shapes_and_types(0, *handle_data);
} else {
c->set_output_handle_shapes_and_types(
0, {{c->UnknownShape(), element_dtype}});
return Status::OK();
}
const shape_inference::ShapeAndType& list_shape_type = (*handle_data)[0];
shape_inference::ShapeHandle item_shape = c->input(2);
TF_RETURN_IF_ERROR(
c->Merge(item_shape, list_shape_type.shape, &item_shape));
c->set_output_handle_shapes_and_types(0, *handle_data);
return Status::OK();
});
@ -484,7 +490,7 @@ REGISTER_OP("TensorListGather")
TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
auto* handle_data = c->input_handle_shapes_and_types(0);
shape_inference::ShapeHandle element_shape = c->UnknownShape();
if (handle_data != nullptr) {
if (IsValidTensorListHandleData(handle_data)) {
const shape_inference::ShapeAndType& list_shape_type =
(*handle_data)[0];
element_shape = list_shape_type.shape;
@ -563,7 +569,7 @@ REGISTER_OP("TensorListScatterIntoExistingList")
shape_inference::ShapeHandle element_shape = c->UnknownShape();
auto* handle_data = c->input_handle_shapes_and_types(0);
if (handle_data != nullptr) {
if (IsValidTensorListHandleData(handle_data)) {
TF_RETURN_IF_ERROR(VerifyHandleData(c, *handle_data, element_dtype));
element_shape = GetElementShapeFromHandleData(*handle_data);
}