From 8cc97997b7bda51d553dd251689023ab53344a28 Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Tue, 2 Jun 2020 09:47:13 -0700 Subject: [PATCH] Check handle_data vector size before inspecting first element in list_ops shape inference functions. The current behavior can segfault. PiperOrigin-RevId: 314348644 Change-Id: I16e2acde203b7f2ad505cb1b89dda892ec660ee9 --- tensorflow/core/ops/list_ops.cc | 40 +++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc index 0a4a30e0309..fde8979fd35 100644 --- a/tensorflow/core/ops/list_ops.cc +++ b/tensorflow/core/ops/list_ops.cc @@ -42,6 +42,11 @@ Status VerifyHandleData( return Status::OK(); } +bool IsValidTensorListHandleData( + const std::vector* handle_data) { + return handle_data != nullptr && handle_data->size() == 1; +} + // Assumes that the handle_data is valid. shape_inference::ShapeHandle GetElementShapeFromHandleData( const std::vector& shapes_and_types) { @@ -83,7 +88,7 @@ REGISTER_OP("TensorListPushBack") return errors::InvalidArgument( "Trying to push to list with wrong variant data."); } - if (handle_data != nullptr && handle_data->size() == 1) { + if (IsValidTensorListHandleData(handle_data)) { const shape_inference::ShapeAndType& list_shape_type = (*handle_data)[0]; if (list_shape_type.dtype != element_dtype) { @@ -130,7 +135,7 @@ REGISTER_OP("TensorListPushBackBatch") return errors::InvalidArgument( "Trying to push to list with wrong variant data."); } - if (handle_data != nullptr && handle_data->size() == 1) { + if (IsValidTensorListHandleData(handle_data)) { const shape_inference::ShapeAndType& list_shape_type = (*handle_data)[0]; if (list_shape_type.dtype != element_dtype) { @@ -171,7 +176,7 @@ REGISTER_OP("TensorListPopBack") return errors::InvalidArgument( "Trying to read from list with invalid variant data."); } - if (handle_data != nullptr && handle_data->size() == 1) { + if (IsValidTensorListHandleData(handle_data)) { const shape_inference::ShapeAndType& list_shape_type = (*handle_data)[0]; if (list_shape_type.dtype != element_dtype) { @@ -208,7 +213,7 @@ REGISTER_OP("TensorListStack") return errors::InvalidArgument( "Trying to read from list with wrong variant data."); } - if (handle_data != nullptr && handle_data->size() == 1) { + if (IsValidTensorListHandleData(handle_data)) { const shape_inference::ShapeAndType& list_shape_type = (*handle_data)[0]; if (list_shape_type.dtype != element_dtype) { @@ -252,7 +257,7 @@ Status TensorListConcatShapeInference( return errors::InvalidArgument( "Trying to read from list with wrong variant data."); } - if (handle_data != nullptr && handle_data->size() == 1) { + if (IsValidTensorListHandleData(handle_data)) { const shape_inference::ShapeAndType& list_shape_type = (*handle_data)[0]; if (list_shape_type.dtype != element_dtype) { return errors::InvalidArgument( @@ -374,7 +379,7 @@ REGISTER_OP("TensorListElementShape") .Attr("shape_type: {int32, int64}") .SetShapeFn([](shape_inference::InferenceContext* c) { auto* handle_data = c->input_handle_shapes_and_types(0); - if (handle_data == nullptr) { + if (!IsValidTensorListHandleData(handle_data)) { c->set_output(0, c->Vector(c->UnknownDim())); return Status::OK(); } @@ -412,7 +417,7 @@ REGISTER_OP("TensorListGetItem") TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype)); auto* handle_data = c->input_handle_shapes_and_types(0); shape_inference::ShapeHandle element_shape = c->UnknownShape(); - if (handle_data != nullptr) { + if (IsValidTensorListHandleData(handle_data)) { const shape_inference::ShapeAndType& list_shape_type = (*handle_data)[0]; element_shape = list_shape_type.shape; @@ -443,7 +448,7 @@ REGISTER_OP("TensorListResize") TF_RETURN_IF_ERROR(c->WithRank(size_shape, 0, &unused)); c->set_output(0, c->Scalar()); auto* handle_data = c->input_handle_shapes_and_types(0); - if (handle_data != nullptr) { + if (IsValidTensorListHandleData(handle_data)) { c->set_output_handle_shapes_and_types(0, *handle_data); } return Status::OK(); @@ -460,16 +465,17 @@ REGISTER_OP("TensorListSetItem") TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype)); auto* handle_data = c->input_handle_shapes_and_types(0); c->set_output(0, c->Scalar()); - if (handle_data == nullptr) { + if (IsValidTensorListHandleData(handle_data)) { + const shape_inference::ShapeAndType& list_shape_type = + (*handle_data)[0]; + shape_inference::ShapeHandle item_shape = c->input(2); + TF_RETURN_IF_ERROR( + c->Merge(item_shape, list_shape_type.shape, &item_shape)); + c->set_output_handle_shapes_and_types(0, *handle_data); + } else { c->set_output_handle_shapes_and_types( 0, {{c->UnknownShape(), element_dtype}}); - return Status::OK(); } - const shape_inference::ShapeAndType& list_shape_type = (*handle_data)[0]; - shape_inference::ShapeHandle item_shape = c->input(2); - TF_RETURN_IF_ERROR( - c->Merge(item_shape, list_shape_type.shape, &item_shape)); - c->set_output_handle_shapes_and_types(0, *handle_data); return Status::OK(); }); @@ -484,7 +490,7 @@ REGISTER_OP("TensorListGather") TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype)); auto* handle_data = c->input_handle_shapes_and_types(0); shape_inference::ShapeHandle element_shape = c->UnknownShape(); - if (handle_data != nullptr) { + if (IsValidTensorListHandleData(handle_data)) { const shape_inference::ShapeAndType& list_shape_type = (*handle_data)[0]; element_shape = list_shape_type.shape; @@ -563,7 +569,7 @@ REGISTER_OP("TensorListScatterIntoExistingList") shape_inference::ShapeHandle element_shape = c->UnknownShape(); auto* handle_data = c->input_handle_shapes_and_types(0); - if (handle_data != nullptr) { + if (IsValidTensorListHandleData(handle_data)) { TF_RETURN_IF_ERROR(VerifyHandleData(c, *handle_data, element_dtype)); element_shape = GetElementShapeFromHandleData(*handle_data); }