Fix a memory bug in the C shape inference API.

PiperOrigin-RevId: 262171100
This commit is contained in:
A. Unique TensorFlower 2019-08-07 10:57:19 -07:00 committed by TensorFlower Gardener
parent 6c04357909
commit c6719f2091
2 changed files with 11 additions and 1 deletions

View File

@ -1092,6 +1092,10 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
std::vector<Tensor> all_input_tensors;
// Update the vector with information from `input_tensors` if provided.
if (input_tensors != nullptr) {
// Note that we take the address of the elements in `all_input_tensors`
// below. Allocate enough space so that no reallocation happens, which will
// make the pointers invalid.
all_input_tensors.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
if (input_tensors[i] == nullptr) continue;
all_input_tensors.emplace_back();

View File

@ -566,13 +566,19 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
TFE_OpSetAttrType(fill_op, "T", TF_FLOAT);
TFE_OpSetAttrType(fill_op, "Tshape", TF_INT32);
float five = 5.0;
TFE_TensorHandle* scalar = TestScalarTensorHandle(five);
TF_Tensor* scalarTensor = TFE_TensorHandleResolve(scalar, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
CheckOutputShapes(fill_op,
/* input_shapes*/ {unknown_shape(), unknown_shape()},
/* input_tensors*/ {tensor_1X1X6, nullptr},
/* input_tensors*/ {tensor_1X1X6, scalarTensor},
/*expected_shape*/ make_shape({1, 1, 6}));
TFE_DeleteOp(fill_op);
fill_op = nullptr;
TFE_DeleteTensorHandle(scalar);
TF_DeleteTensor(scalarTensor);
TF_DeleteTensor(tensor_1X1X6);
TF_DeleteTensor(tensor_1X6);
}