Convert TF_SavedModel to a direct pointer to tensorflow::SavedModelAPI. This saves us an extra allocation when loading a savedmodel, and extra indirection on all saved model functions.
PiperOrigin-RevId: 312570488 Change-Id: I16f21a0124af269f6d2b0e1065fbd1aa6a4224b2
This commit is contained in:
parent
4148ee2e95
commit
6e4fdec80e
@ -155,6 +155,7 @@ cc_library(
|
|||||||
"saved_model_api_type.h",
|
"saved_model_api_type.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/c:conversion_macros",
|
||||||
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
|
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -41,7 +41,7 @@ TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
|
|||||||
if (!status->status.ok()) {
|
if (!status->status.ok()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return new TF_SavedModel{std::move(result)};
|
return tensorflow::wrap(result.release());
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
|
TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
|
||||||
@ -60,17 +60,19 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
|
|||||||
if (!status->status.ok()) {
|
if (!status->status.ok()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return new TF_SavedModel{std::move(result)};
|
return tensorflow::wrap(result.release());
|
||||||
}
|
}
|
||||||
|
|
||||||
void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; }
|
void TF_DeleteSavedModel(TF_SavedModel* model) {
|
||||||
|
delete tensorflow::unwrap(model);
|
||||||
|
}
|
||||||
|
|
||||||
TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
|
TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
|
||||||
const char* function_path,
|
const char* function_path,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
tensorflow::ConcreteFunction* result = nullptr;
|
tensorflow::ConcreteFunction* result = nullptr;
|
||||||
tensorflow::Status get_function_status =
|
tensorflow::Status get_function_status =
|
||||||
model->saved_model->GetFunction(function_path, &result);
|
tensorflow::unwrap(model)->GetFunction(function_path, &result);
|
||||||
status->status.Update(get_function_status);
|
status->status.Update(get_function_status);
|
||||||
if (!get_function_status.ok()) {
|
if (!get_function_status.ok()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -82,7 +84,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
|
|||||||
TF_SavedModel* model, const char* signature_def_key, TF_Status* status) {
|
TF_SavedModel* model, const char* signature_def_key, TF_Status* status) {
|
||||||
tensorflow::ConcreteFunction* result = nullptr;
|
tensorflow::ConcreteFunction* result = nullptr;
|
||||||
tensorflow::Status get_function_status =
|
tensorflow::Status get_function_status =
|
||||||
model->saved_model->GetSignatureDefFunction(signature_def_key, &result);
|
tensorflow::unwrap(model)->GetSignatureDefFunction(signature_def_key,
|
||||||
|
&result);
|
||||||
status->status.Update(get_function_status);
|
status->status.Update(get_function_status);
|
||||||
if (!get_function_status.ok()) {
|
if (!get_function_status.ok()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -91,7 +94,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
|
|||||||
}
|
}
|
||||||
|
|
||||||
TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) {
|
TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) {
|
||||||
return new TF_ConcreteFunctionList{model->saved_model->ListFunctions()};
|
return new TF_ConcreteFunctionList{
|
||||||
|
tensorflow::unwrap(model)->ListFunctions()};
|
||||||
}
|
}
|
||||||
|
|
||||||
} // end extern "C"
|
} // end extern "C"
|
||||||
|
@ -18,13 +18,18 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
#include "tensorflow/c/conversion_macros.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||||
|
|
||||||
// Internal structures used by the SavedModel C API. These are likely to change
|
// Internal structures used by the SavedModel C API. These are likely to change
|
||||||
// and should not be depended on.
|
// and should not be depended on.
|
||||||
|
|
||||||
struct TF_SavedModel {
|
typedef struct TF_SavedModel TF_SavedModel;
|
||||||
std::unique_ptr<tensorflow::SavedModelAPI> saved_model;
|
|
||||||
};
|
namespace tensorflow {
|
||||||
|
|
||||||
|
DEFINE_CONVERSION_FUNCTIONS(tensorflow::SavedModelAPI, TF_SavedModel)
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
|
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
|
||||||
|
Loading…
Reference in New Issue
Block a user