diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index af1c0ea6833..f04f0175696 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -1092,6 +1092,10 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes, std::vector all_input_tensors; // Update the vector with information from `input_tensors` if provided. if (input_tensors != nullptr) { + // Note that we take the address of the elements in `all_input_tensors` + // below. Allocate enough space so that no reallocation happens, which will + // make the pointers invalid. + all_input_tensors.reserve(num_inputs); for (int i = 0; i < num_inputs; ++i) { if (input_tensors[i] == nullptr) continue; all_input_tensors.emplace_back(); diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index 4b49b90e293..ed0ab7c26f8 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -566,13 +566,19 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) { TFE_OpSetAttrType(fill_op, "T", TF_FLOAT); TFE_OpSetAttrType(fill_op, "Tshape", TF_INT32); + float five = 5.0; + TFE_TensorHandle* scalar = TestScalarTensorHandle(five); + TF_Tensor* scalarTensor = TFE_TensorHandleResolve(scalar, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); CheckOutputShapes(fill_op, /* input_shapes*/ {unknown_shape(), unknown_shape()}, - /* input_tensors*/ {tensor_1X1X6, nullptr}, + /* input_tensors*/ {tensor_1X1X6, scalarTensor}, /*expected_shape*/ make_shape({1, 1, 6})); TFE_DeleteOp(fill_op); fill_op = nullptr; + TFE_DeleteTensorHandle(scalar); + TF_DeleteTensor(scalarTensor); TF_DeleteTensor(tensor_1X1X6); TF_DeleteTensor(tensor_1X6); }