Convert TF_SavedModel to a direct pointer to tensorflow::SavedModelAPI. This saves us an extra allocation when loading a savedmodel, and extra indirection on all saved model functions.

PiperOrigin-RevId: 312570488
Change-Id: I16f21a0124af269f6d2b0e1065fbd1aa6a4224b2
This commit is contained in:
Brian Zhao 2020-05-20 15:46:49 -07:00 committed by TensorFlower Gardener
parent 4148ee2e95
commit 6e4fdec80e
3 changed files with 19 additions and 9 deletions
tensorflow/c/experimental/saved_model/internal

View File

@ -155,6 +155,7 @@ cc_library(
"saved_model_api_type.h",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
],
)

View File

@ -41,7 +41,7 @@ TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
if (!status->status.ok()) {
return nullptr;
}
return new TF_SavedModel{std::move(result)};
return tensorflow::wrap(result.release());
}
TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
@ -60,17 +60,19 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
if (!status->status.ok()) {
return nullptr;
}
return new TF_SavedModel{std::move(result)};
return tensorflow::wrap(result.release());
}
void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; }
void TF_DeleteSavedModel(TF_SavedModel* model) {
delete tensorflow::unwrap(model);
}
TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
const char* function_path,
TF_Status* status) {
tensorflow::ConcreteFunction* result = nullptr;
tensorflow::Status get_function_status =
model->saved_model->GetFunction(function_path, &result);
tensorflow::unwrap(model)->GetFunction(function_path, &result);
status->status.Update(get_function_status);
if (!get_function_status.ok()) {
return nullptr;
@ -82,7 +84,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
TF_SavedModel* model, const char* signature_def_key, TF_Status* status) {
tensorflow::ConcreteFunction* result = nullptr;
tensorflow::Status get_function_status =
model->saved_model->GetSignatureDefFunction(signature_def_key, &result);
tensorflow::unwrap(model)->GetSignatureDefFunction(signature_def_key,
&result);
status->status.Update(get_function_status);
if (!get_function_status.ok()) {
return nullptr;
@ -91,7 +94,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
}
TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) {
return new TF_ConcreteFunctionList{model->saved_model->ListFunctions()};
return new TF_ConcreteFunctionList{
tensorflow::unwrap(model)->ListFunctions()};
}
} // end extern "C"

View File

@ -18,13 +18,18 @@ limitations under the License.
#include <memory>
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
// Internal structures used by the SavedModel C API. These are likely to change
// and should not be depended on.
struct TF_SavedModel {
std::unique_ptr<tensorflow::SavedModelAPI> saved_model;
};
typedef struct TF_SavedModel TF_SavedModel;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::SavedModelAPI, TF_SavedModel)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_