Add TF_DeleteLibraryHandle, for freeing the memory allocated by TF_LoadLibrary.
This is useful when calling the C API from C++, to avoid memory leaks being reported. Change: 132937873
This commit is contained in:
parent
6de77d7f29
commit
56b003d724
@ -7,6 +7,7 @@ load(
|
|||||||
"//tensorflow:tensorflow.bzl",
|
"//tensorflow:tensorflow.bzl",
|
||||||
"tf_cc_test",
|
"tf_cc_test",
|
||||||
"tf_cuda_library",
|
"tf_cuda_library",
|
||||||
|
"tf_custom_op_library",
|
||||||
)
|
)
|
||||||
|
|
||||||
# For platform specific build config
|
# For platform specific build config
|
||||||
@ -60,6 +61,7 @@ tf_cc_test(
|
|||||||
name = "c_api_test",
|
name = "c_api_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["c_api_test.cc"],
|
srcs = ["c_api_test.cc"],
|
||||||
|
data = [":test_op.so"],
|
||||||
linkopts = select({
|
linkopts = select({
|
||||||
"//tensorflow:darwin": ["-headerpad_max_install_names"],
|
"//tensorflow:darwin": ["-headerpad_max_install_names"],
|
||||||
"//conditions:default": [],
|
"//conditions:default": [],
|
||||||
@ -81,6 +83,11 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_custom_op_library(
|
||||||
|
name = "test_op.so",
|
||||||
|
srcs = ["test_op.cc"],
|
||||||
|
)
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Google-internal targets.
|
# Google-internal targets.
|
||||||
|
|
||||||
|
@ -636,6 +636,11 @@ TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
|
|||||||
|
|
||||||
TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; }
|
TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; }
|
||||||
|
|
||||||
|
void TF_DeleteLibraryHandle(TF_Library* lib_handle) {
|
||||||
|
free(const_cast<void*>(lib_handle->op_list.data));
|
||||||
|
delete lib_handle;
|
||||||
|
}
|
||||||
|
|
||||||
TF_Buffer* TF_GetAllOpList() {
|
TF_Buffer* TF_GetAllOpList() {
|
||||||
std::vector<tensorflow::OpDef> op_defs;
|
std::vector<tensorflow::OpDef> op_defs;
|
||||||
tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs);
|
tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs);
|
||||||
|
@ -958,6 +958,10 @@ extern TF_Library* TF_LoadLibrary(const char* library_filename,
|
|||||||
// ops defined in the library.
|
// ops defined in the library.
|
||||||
extern TF_Buffer TF_GetOpList(TF_Library* lib_handle);
|
extern TF_Buffer TF_GetOpList(TF_Library* lib_handle);
|
||||||
|
|
||||||
|
// Frees the memory associated with the library handle.
|
||||||
|
// Does NOT unload the library.
|
||||||
|
extern void TF_DeleteLibraryHandle(TF_Library* lib_handle);
|
||||||
|
|
||||||
// Get the OpList of all OpDefs defined in this address space.
|
// Get the OpList of all OpDefs defined in this address space.
|
||||||
// Returns a TF_Buffer, ownership of which is transferred to the caller
|
// Returns a TF_Buffer, ownership of which is transferred to the caller
|
||||||
// (and can be freed using TF_DeleteBuffer).
|
// (and can be freed using TF_DeleteBuffer).
|
||||||
|
@ -91,6 +91,26 @@ TEST(CApi, AllocateTensor) {
|
|||||||
TF_DeleteTensor(t);
|
TF_DeleteTensor(t);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CApi, LibraryLoadFunctions) {
|
||||||
|
// Load the library.
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TF_Library* lib =
|
||||||
|
TF_LoadLibrary("tensorflow/c/test_op.so", status);
|
||||||
|
TF_Code code = TF_GetCode(status);
|
||||||
|
string status_msg(TF_Message(status));
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
ASSERT_EQ(TF_OK, code) << status_msg;
|
||||||
|
|
||||||
|
// Test op list.
|
||||||
|
TF_Buffer op_list_buf = TF_GetOpList(lib);
|
||||||
|
tensorflow::OpList op_list;
|
||||||
|
EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length));
|
||||||
|
ASSERT_EQ(op_list.op_size(), 1);
|
||||||
|
EXPECT_EQ("TestCApi", op_list.op(0).name());
|
||||||
|
|
||||||
|
TF_DeleteLibraryHandle(lib);
|
||||||
|
}
|
||||||
|
|
||||||
static void TestEncodeDecode(int line, const std::vector<string>& data) {
|
static void TestEncodeDecode(int line, const std::vector<string>& data) {
|
||||||
const tensorflow::int64 n = data.size();
|
const tensorflow::int64 n = data.size();
|
||||||
for (const std::vector<tensorflow::int64>& dims :
|
for (const std::vector<tensorflow::int64>& dims :
|
||||||
|
23
tensorflow/c/test_op.cc
Normal file
23
tensorflow/c/test_op.cc
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
REGISTER_OP("TestCApi").Doc(R"doc(Used to test C API)doc");
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -85,7 +85,7 @@ Status LoadLibrary(const char* library_filename, void** result,
|
|||||||
}
|
}
|
||||||
string str;
|
string str;
|
||||||
op_list.SerializeToString(&str);
|
op_list.SerializeToString(&str);
|
||||||
char* str_buf = reinterpret_cast<char*>(operator new(str.length()));
|
char* str_buf = reinterpret_cast<char*>(malloc(str.length()));
|
||||||
memcpy(str_buf, str.data(), str.length());
|
memcpy(str_buf, str.data(), str.length());
|
||||||
*buf = str_buf;
|
*buf = str_buf;
|
||||||
*len = str.length();
|
*len = str.length();
|
||||||
|
Loading…
Reference in New Issue
Block a user