Check handle_data vector size before inspecting first element in list_ops shape inference functions. The current behavior can segfault.
PiperOrigin-RevId: 314348644 Change-Id: I16e2acde203b7f2ad505cb1b89dda892ec660ee9
This commit is contained in:
parent
0b4ae9dda2
commit
8cc97997b7
@ -42,6 +42,11 @@ Status VerifyHandleData(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool IsValidTensorListHandleData(
|
||||
const std::vector<shape_inference::ShapeAndType>* handle_data) {
|
||||
return handle_data != nullptr && handle_data->size() == 1;
|
||||
}
|
||||
|
||||
// Assumes that the handle_data is valid.
|
||||
shape_inference::ShapeHandle GetElementShapeFromHandleData(
|
||||
const std::vector<shape_inference::ShapeAndType>& shapes_and_types) {
|
||||
@ -83,7 +88,7 @@ REGISTER_OP("TensorListPushBack")
|
||||
return errors::InvalidArgument(
|
||||
"Trying to push to list with wrong variant data.");
|
||||
}
|
||||
if (handle_data != nullptr && handle_data->size() == 1) {
|
||||
if (IsValidTensorListHandleData(handle_data)) {
|
||||
const shape_inference::ShapeAndType& list_shape_type =
|
||||
(*handle_data)[0];
|
||||
if (list_shape_type.dtype != element_dtype) {
|
||||
@ -130,7 +135,7 @@ REGISTER_OP("TensorListPushBackBatch")
|
||||
return errors::InvalidArgument(
|
||||
"Trying to push to list with wrong variant data.");
|
||||
}
|
||||
if (handle_data != nullptr && handle_data->size() == 1) {
|
||||
if (IsValidTensorListHandleData(handle_data)) {
|
||||
const shape_inference::ShapeAndType& list_shape_type =
|
||||
(*handle_data)[0];
|
||||
if (list_shape_type.dtype != element_dtype) {
|
||||
@ -171,7 +176,7 @@ REGISTER_OP("TensorListPopBack")
|
||||
return errors::InvalidArgument(
|
||||
"Trying to read from list with invalid variant data.");
|
||||
}
|
||||
if (handle_data != nullptr && handle_data->size() == 1) {
|
||||
if (IsValidTensorListHandleData(handle_data)) {
|
||||
const shape_inference::ShapeAndType& list_shape_type =
|
||||
(*handle_data)[0];
|
||||
if (list_shape_type.dtype != element_dtype) {
|
||||
@ -208,7 +213,7 @@ REGISTER_OP("TensorListStack")
|
||||
return errors::InvalidArgument(
|
||||
"Trying to read from list with wrong variant data.");
|
||||
}
|
||||
if (handle_data != nullptr && handle_data->size() == 1) {
|
||||
if (IsValidTensorListHandleData(handle_data)) {
|
||||
const shape_inference::ShapeAndType& list_shape_type =
|
||||
(*handle_data)[0];
|
||||
if (list_shape_type.dtype != element_dtype) {
|
||||
@ -252,7 +257,7 @@ Status TensorListConcatShapeInference(
|
||||
return errors::InvalidArgument(
|
||||
"Trying to read from list with wrong variant data.");
|
||||
}
|
||||
if (handle_data != nullptr && handle_data->size() == 1) {
|
||||
if (IsValidTensorListHandleData(handle_data)) {
|
||||
const shape_inference::ShapeAndType& list_shape_type = (*handle_data)[0];
|
||||
if (list_shape_type.dtype != element_dtype) {
|
||||
return errors::InvalidArgument(
|
||||
@ -374,7 +379,7 @@ REGISTER_OP("TensorListElementShape")
|
||||
.Attr("shape_type: {int32, int64}")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
auto* handle_data = c->input_handle_shapes_and_types(0);
|
||||
if (handle_data == nullptr) {
|
||||
if (!IsValidTensorListHandleData(handle_data)) {
|
||||
c->set_output(0, c->Vector(c->UnknownDim()));
|
||||
return Status::OK();
|
||||
}
|
||||
@ -412,7 +417,7 @@ REGISTER_OP("TensorListGetItem")
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
|
||||
auto* handle_data = c->input_handle_shapes_and_types(0);
|
||||
shape_inference::ShapeHandle element_shape = c->UnknownShape();
|
||||
if (handle_data != nullptr) {
|
||||
if (IsValidTensorListHandleData(handle_data)) {
|
||||
const shape_inference::ShapeAndType& list_shape_type =
|
||||
(*handle_data)[0];
|
||||
element_shape = list_shape_type.shape;
|
||||
@ -443,7 +448,7 @@ REGISTER_OP("TensorListResize")
|
||||
TF_RETURN_IF_ERROR(c->WithRank(size_shape, 0, &unused));
|
||||
c->set_output(0, c->Scalar());
|
||||
auto* handle_data = c->input_handle_shapes_and_types(0);
|
||||
if (handle_data != nullptr) {
|
||||
if (IsValidTensorListHandleData(handle_data)) {
|
||||
c->set_output_handle_shapes_and_types(0, *handle_data);
|
||||
}
|
||||
return Status::OK();
|
||||
@ -460,16 +465,17 @@ REGISTER_OP("TensorListSetItem")
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
|
||||
auto* handle_data = c->input_handle_shapes_and_types(0);
|
||||
c->set_output(0, c->Scalar());
|
||||
if (handle_data == nullptr) {
|
||||
if (IsValidTensorListHandleData(handle_data)) {
|
||||
const shape_inference::ShapeAndType& list_shape_type =
|
||||
(*handle_data)[0];
|
||||
shape_inference::ShapeHandle item_shape = c->input(2);
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Merge(item_shape, list_shape_type.shape, &item_shape));
|
||||
c->set_output_handle_shapes_and_types(0, *handle_data);
|
||||
} else {
|
||||
c->set_output_handle_shapes_and_types(
|
||||
0, {{c->UnknownShape(), element_dtype}});
|
||||
return Status::OK();
|
||||
}
|
||||
const shape_inference::ShapeAndType& list_shape_type = (*handle_data)[0];
|
||||
shape_inference::ShapeHandle item_shape = c->input(2);
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Merge(item_shape, list_shape_type.shape, &item_shape));
|
||||
c->set_output_handle_shapes_and_types(0, *handle_data);
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
@ -484,7 +490,7 @@ REGISTER_OP("TensorListGather")
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
|
||||
auto* handle_data = c->input_handle_shapes_and_types(0);
|
||||
shape_inference::ShapeHandle element_shape = c->UnknownShape();
|
||||
if (handle_data != nullptr) {
|
||||
if (IsValidTensorListHandleData(handle_data)) {
|
||||
const shape_inference::ShapeAndType& list_shape_type =
|
||||
(*handle_data)[0];
|
||||
element_shape = list_shape_type.shape;
|
||||
@ -563,7 +569,7 @@ REGISTER_OP("TensorListScatterIntoExistingList")
|
||||
shape_inference::ShapeHandle element_shape = c->UnknownShape();
|
||||
|
||||
auto* handle_data = c->input_handle_shapes_and_types(0);
|
||||
if (handle_data != nullptr) {
|
||||
if (IsValidTensorListHandleData(handle_data)) {
|
||||
TF_RETURN_IF_ERROR(VerifyHandleData(c, *handle_data, element_dtype));
|
||||
element_shape = GetElementShapeFromHandleData(*handle_data);
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user