Add TF_DeleteLibraryHandle, for freeing the memory allocated by TF_LoadLibrary.

This is useful when calling the C API from C++, to avoid memory leaks being
reported.
Change: 132937873
This commit is contained in:
A. Unique TensorFlower 2016-09-12 16:07:46 -08:00 committed by TensorFlower Gardener
parent 6de77d7f29
commit 56b003d724
6 changed files with 60 additions and 1 deletions

View File

@ -7,6 +7,7 @@ load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
"tf_cuda_library",
"tf_custom_op_library",
)
# For platform specific build config
@ -60,6 +61,7 @@ tf_cc_test(
name = "c_api_test",
size = "small",
srcs = ["c_api_test.cc"],
data = [":test_op.so"],
linkopts = select({
"//tensorflow:darwin": ["-headerpad_max_install_names"],
"//conditions:default": [],
@ -81,6 +83,11 @@ tf_cc_test(
],
)
tf_custom_op_library(
name = "test_op.so",
srcs = ["test_op.cc"],
)
# -----------------------------------------------------------------------------
# Google-internal targets.

View File

@ -636,6 +636,11 @@ TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; }
void TF_DeleteLibraryHandle(TF_Library* lib_handle) {
free(const_cast<void*>(lib_handle->op_list.data));
delete lib_handle;
}
TF_Buffer* TF_GetAllOpList() {
std::vector<tensorflow::OpDef> op_defs;
tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs);

View File

@ -958,6 +958,10 @@ extern TF_Library* TF_LoadLibrary(const char* library_filename,
// ops defined in the library.
extern TF_Buffer TF_GetOpList(TF_Library* lib_handle);
// Frees the memory associated with the library handle.
// Does NOT unload the library.
extern void TF_DeleteLibraryHandle(TF_Library* lib_handle);
// Get the OpList of all OpDefs defined in this address space.
// Returns a TF_Buffer, ownership of which is transferred to the caller
// (and can be freed using TF_DeleteBuffer).

View File

@ -91,6 +91,26 @@ TEST(CApi, AllocateTensor) {
TF_DeleteTensor(t);
}
TEST(CApi, LibraryLoadFunctions) {
// Load the library.
TF_Status* status = TF_NewStatus();
TF_Library* lib =
TF_LoadLibrary("tensorflow/c/test_op.so", status);
TF_Code code = TF_GetCode(status);
string status_msg(TF_Message(status));
TF_DeleteStatus(status);
ASSERT_EQ(TF_OK, code) << status_msg;
// Test op list.
TF_Buffer op_list_buf = TF_GetOpList(lib);
tensorflow::OpList op_list;
EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length));
ASSERT_EQ(op_list.op_size(), 1);
EXPECT_EQ("TestCApi", op_list.op(0).name());
TF_DeleteLibraryHandle(lib);
}
static void TestEncodeDecode(int line, const std::vector<string>& data) {
const tensorflow::int64 n = data.size();
for (const std::vector<tensorflow::int64>& dims :

23
tensorflow/c/test_op.cc Normal file
View File

@ -0,0 +1,23 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
REGISTER_OP("TestCApi").Doc(R"doc(Used to test C API)doc");
} // namespace tensorflow

View File

@ -85,7 +85,7 @@ Status LoadLibrary(const char* library_filename, void** result,
}
string str;
op_list.SerializeToString(&str);
char* str_buf = reinterpret_cast<char*>(operator new(str.length()));
char* str_buf = reinterpret_cast<char*>(malloc(str.length()));
memcpy(str_buf, str.data(), str.length());
*buf = str_buf;
*len = str.length();