Convert TF_SavedModel to a direct pointer to tensorflow::SavedModelAPI. This saves us an extra allocation when loading a savedmodel, and extra indirection on all saved model functions.

PiperOrigin-RevId: 312570488
Change-Id: I16f21a0124af269f6d2b0e1065fbd1aa6a4224b2
This commit is contained in:
Brian Zhao 2020-05-20 15:46:49 -07:00 committed by TensorFlower Gardener
parent 4148ee2e95
commit 6e4fdec80e
3 changed files with 19 additions and 9 deletions

View File

@ -155,6 +155,7 @@ cc_library(
"saved_model_api_type.h", "saved_model_api_type.h",
], ],
deps = [ deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c/experimental/saved_model/core:saved_model_api", "//tensorflow/c/experimental/saved_model/core:saved_model_api",
], ],
) )

View File

@ -41,7 +41,7 @@ TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
if (!status->status.ok()) { if (!status->status.ok()) {
return nullptr; return nullptr;
} }
return new TF_SavedModel{std::move(result)}; return tensorflow::wrap(result.release());
} }
TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx, TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
@ -60,17 +60,19 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
if (!status->status.ok()) { if (!status->status.ok()) {
return nullptr; return nullptr;
} }
return new TF_SavedModel{std::move(result)}; return tensorflow::wrap(result.release());
} }
void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; } void TF_DeleteSavedModel(TF_SavedModel* model) {
delete tensorflow::unwrap(model);
}
TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model, TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
const char* function_path, const char* function_path,
TF_Status* status) { TF_Status* status) {
tensorflow::ConcreteFunction* result = nullptr; tensorflow::ConcreteFunction* result = nullptr;
tensorflow::Status get_function_status = tensorflow::Status get_function_status =
model->saved_model->GetFunction(function_path, &result); tensorflow::unwrap(model)->GetFunction(function_path, &result);
status->status.Update(get_function_status); status->status.Update(get_function_status);
if (!get_function_status.ok()) { if (!get_function_status.ok()) {
return nullptr; return nullptr;
@ -82,7 +84,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
TF_SavedModel* model, const char* signature_def_key, TF_Status* status) { TF_SavedModel* model, const char* signature_def_key, TF_Status* status) {
tensorflow::ConcreteFunction* result = nullptr; tensorflow::ConcreteFunction* result = nullptr;
tensorflow::Status get_function_status = tensorflow::Status get_function_status =
model->saved_model->GetSignatureDefFunction(signature_def_key, &result); tensorflow::unwrap(model)->GetSignatureDefFunction(signature_def_key,
&result);
status->status.Update(get_function_status); status->status.Update(get_function_status);
if (!get_function_status.ok()) { if (!get_function_status.ok()) {
return nullptr; return nullptr;
@ -91,7 +94,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
} }
TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) { TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) {
return new TF_ConcreteFunctionList{model->saved_model->ListFunctions()}; return new TF_ConcreteFunctionList{
tensorflow::unwrap(model)->ListFunctions()};
} }
} // end extern "C" } // end extern "C"

View File

@ -18,13 +18,18 @@ limitations under the License.
#include <memory> #include <memory>
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" #include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
// Internal structures used by the SavedModel C API. These are likely to change // Internal structures used by the SavedModel C API. These are likely to change
// and should not be depended on. // and should not be depended on.
struct TF_SavedModel { typedef struct TF_SavedModel TF_SavedModel;
std::unique_ptr<tensorflow::SavedModelAPI> saved_model;
}; namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::SavedModelAPI, TF_SavedModel)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_ #endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_