Fix a memory bug in the C shape inference API.
PiperOrigin-RevId: 262171100
This commit is contained in:
parent
6c04357909
commit
c6719f2091
@ -1092,6 +1092,10 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
|
||||
std::vector<Tensor> all_input_tensors;
|
||||
// Update the vector with information from `input_tensors` if provided.
|
||||
if (input_tensors != nullptr) {
|
||||
// Note that we take the address of the elements in `all_input_tensors`
|
||||
// below. Allocate enough space so that no reallocation happens, which will
|
||||
// make the pointers invalid.
|
||||
all_input_tensors.reserve(num_inputs);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
if (input_tensors[i] == nullptr) continue;
|
||||
all_input_tensors.emplace_back();
|
||||
|
@ -566,13 +566,19 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
|
||||
TFE_OpSetAttrType(fill_op, "T", TF_FLOAT);
|
||||
TFE_OpSetAttrType(fill_op, "Tshape", TF_INT32);
|
||||
|
||||
float five = 5.0;
|
||||
TFE_TensorHandle* scalar = TestScalarTensorHandle(five);
|
||||
TF_Tensor* scalarTensor = TFE_TensorHandleResolve(scalar, status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
CheckOutputShapes(fill_op,
|
||||
/* input_shapes*/ {unknown_shape(), unknown_shape()},
|
||||
/* input_tensors*/ {tensor_1X1X6, nullptr},
|
||||
/* input_tensors*/ {tensor_1X1X6, scalarTensor},
|
||||
/*expected_shape*/ make_shape({1, 1, 6}));
|
||||
TFE_DeleteOp(fill_op);
|
||||
fill_op = nullptr;
|
||||
|
||||
TFE_DeleteTensorHandle(scalar);
|
||||
TF_DeleteTensor(scalarTensor);
|
||||
TF_DeleteTensor(tensor_1X1X6);
|
||||
TF_DeleteTensor(tensor_1X6);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user