diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 50887f70b5b..7f86d345d7a 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -7,6 +7,7 @@ load( "//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_library", + "tf_custom_op_library", ) # For platform specific build config @@ -60,6 +61,7 @@ tf_cc_test( name = "c_api_test", size = "small", srcs = ["c_api_test.cc"], + data = [":test_op.so"], linkopts = select({ "//tensorflow:darwin": ["-headerpad_max_install_names"], "//conditions:default": [], @@ -81,6 +83,11 @@ tf_cc_test( ], ) +tf_custom_op_library( + name = "test_op.so", + srcs = ["test_op.cc"], +) + # ----------------------------------------------------------------------------- # Google-internal targets. diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 4d524d8eb56..713fef89c58 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -636,6 +636,11 @@ TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) { TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; } +void TF_DeleteLibraryHandle(TF_Library* lib_handle) { + free(const_cast(lib_handle->op_list.data)); + delete lib_handle; +} + TF_Buffer* TF_GetAllOpList() { std::vector op_defs; tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs); diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 6315d203e32..7482e44a8b2 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -958,6 +958,10 @@ extern TF_Library* TF_LoadLibrary(const char* library_filename, // ops defined in the library. extern TF_Buffer TF_GetOpList(TF_Library* lib_handle); +// Frees the memory associated with the library handle. +// Does NOT unload the library. +extern void TF_DeleteLibraryHandle(TF_Library* lib_handle); + // Get the OpList of all OpDefs defined in this address space. // Returns a TF_Buffer, ownership of which is transferred to the caller // (and can be freed using TF_DeleteBuffer). diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 4c2e8f2ae86..bf8083ebe32 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -91,6 +91,26 @@ TEST(CApi, AllocateTensor) { TF_DeleteTensor(t); } +TEST(CApi, LibraryLoadFunctions) { + // Load the library. + TF_Status* status = TF_NewStatus(); + TF_Library* lib = + TF_LoadLibrary("tensorflow/c/test_op.so", status); + TF_Code code = TF_GetCode(status); + string status_msg(TF_Message(status)); + TF_DeleteStatus(status); + ASSERT_EQ(TF_OK, code) << status_msg; + + // Test op list. + TF_Buffer op_list_buf = TF_GetOpList(lib); + tensorflow::OpList op_list; + EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length)); + ASSERT_EQ(op_list.op_size(), 1); + EXPECT_EQ("TestCApi", op_list.op(0).name()); + + TF_DeleteLibraryHandle(lib); +} + static void TestEncodeDecode(int line, const std::vector& data) { const tensorflow::int64 n = data.size(); for (const std::vector& dims : diff --git a/tensorflow/c/test_op.cc b/tensorflow/c/test_op.cc new file mode 100644 index 00000000000..c1a1f1cdb1f --- /dev/null +++ b/tensorflow/c/test_op.cc @@ -0,0 +1,23 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +REGISTER_OP("TestCApi").Doc(R"doc(Used to test C API)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/framework/load_library.cc b/tensorflow/core/framework/load_library.cc index 65f42659a57..12a57756a97 100644 --- a/tensorflow/core/framework/load_library.cc +++ b/tensorflow/core/framework/load_library.cc @@ -85,7 +85,7 @@ Status LoadLibrary(const char* library_filename, void** result, } string str; op_list.SerializeToString(&str); - char* str_buf = reinterpret_cast(operator new(str.length())); + char* str_buf = reinterpret_cast(malloc(str.length())); memcpy(str_buf, str.data(), str.length()); *buf = str_buf; *len = str.length();